diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index a3222bf0..70ba13ef 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -57,7 +57,7 @@ jobs: ls -ltrh ls -ltrh dist - name: Publish to Test PyPI - uses: pypa/gh-action-pypi-publish@v1.9.0 + uses: pypa/gh-action-pypi-publish@v1.10.2 with: repository-url: https://test.pypi.org/legacy/ verbose: true @@ -95,5 +95,5 @@ jobs: name: artifact path: dist - - uses: pypa/gh-action-pypi-publish@v1.9.0 + - uses: pypa/gh-action-pypi-publish@v1.10.2 if: startsWith(github.ref, 'refs/tags') diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 244d6c92..7451bef5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -74,8 +74,9 @@ jobs: if: ${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest'}} uses: codecov/codecov-action@v4 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} with: - token: ${{ secrets.CODECOV_TOKEN }} files: ./coverage.xml flags: unittests name: codecov-umbrella diff --git a/README.md b/README.md index c4730fae..fe1576e8 100644 --- a/README.md +++ b/README.md @@ -47,10 +47,10 @@ x = torch.tensor([ 5.0, -0.2, 0.0, 0.8, 0.0, 1., 1.0, 10.0 ]) # fmt: skip -minisim = caustics.LensSource( +sim = caustics.LensSource( lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100 ) -plt.imshow(minisim(x, quad_level=3), origin="lower") +plt.imshow(sim(x), origin="lower") plt.axis("off") plt.show() ``` @@ -63,7 +63,7 @@ plt.show() newx = x.repeat(20, 1) newx += torch.normal(mean=0, std=0.1 * torch.ones_like(newx)) -images = torch.vmap(minisim)(newx) +images = torch.vmap(sim)(newx) fig, axarr = plt.subplots(4, 5, figsize=(20, 16)) for ax, im in zip(axarr.flatten(), images): @@ -76,7 +76,7 @@ plt.show() ### Automatic Differentiation ```python -J = torch.func.jacfwd(minisim)(x) +J = torch.func.jacfwd(sim)(x) # Plot the new images fig, axarr = plt.subplots(3, 7, figsize=(20, 9)) diff --git a/docs/requirements.txt b/docs/requirements.txt index 209d48f9..a6b3e0ae 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,3 +1,4 @@ +emcee ipywidgets jupyter-book matplotlib diff --git a/docs/source/_toc.yml b/docs/source/_toc.yml index d4a66389..9ad25f2e 100644 --- a/docs/source/_toc.yml +++ b/docs/source/_toc.yml @@ -21,11 +21,10 @@ chapters: - file: tutorials/InvertLensEquation - file: tutorials/Parameters - file: tutorials/Simulators - - file: tutorials/Playground - file: examples/index sections: - file: examples/Example_ImageFit_LM - - file: examples/Example_ImageFit_NUTS + - file: examples/Example_ImageFit_MCMC - file: examples/Example_QSOLensFit - file: contributing - file: frequently_asked_questions diff --git a/docs/source/examples/Example_ImageFit_LM.ipynb b/docs/source/examples/Example_ImageFit_LM.ipynb index 41f06858..d06aa40d 100644 --- a/docs/source/examples/Example_ImageFit_LM.ipynb +++ b/docs/source/examples/Example_ImageFit_LM.ipynb @@ -56,6 +56,7 @@ "cosmology.to(dtype=torch.float32)\n", "\n", "upsample_factor = 1\n", + "quad_level = 3\n", "thx, thy = caustics.utils.meshgrid(\n", " pixelscale / upsample_factor,\n", " upsample_factor * numPix,\n", @@ -117,6 +118,7 @@ " pixels_x=numPix,\n", " pixelscale=pixelscale,\n", " upsample_factor=upsample_factor,\n", + " quad_level=quad_level,\n", " z_s=2.0,\n", ")" ] @@ -128,16 +130,18 @@ "source": [ "## Sample some mock data\n", "\n", - "Here we write out the true values for all the parameters in the model. In total there are 21 parameters, so this is quite a complex model already! We then plot the data so we can see what it is we re trying to fit.\n", - "\n", - "Note that when we sample the simulator we call it with `quad_level=7`. This means the simulator will use gaussian quadrature sub-pixel integration to ensure the brightness of each pixel is very accurately computed." + "Here we write out the true values for all the parameters in the model. In total there are 21 parameters, so this is quite a complex model already! We then plot the data so we can see what it is we re trying to fit." ] }, { "cell_type": "code", "execution_count": null, "id": "7", - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ "# Generate the mock data\n", @@ -178,7 +182,7 @@ "print(true_params)\n", "\n", "# simulate lens, crop extra evaluation for PSF\n", - "true_system = sim(allparams, quad_level=7) # simulate at high resolution\n", + "true_system = sim(allparams)\n", "\n", "fig, axarr = plt.subplots(1, 2, figsize=(15, 8))\n", "axarr[0].imshow(\n", @@ -231,7 +235,7 @@ "res = caustics.utils.batch_lm(\n", " batch_inits,\n", " obs_system.reshape(-1).repeat(10, 1),\n", - " lambda x: sim(x, quad_level=3).reshape(-1),\n", + " lambda x: sim(x).reshape(-1),\n", " C=variance.reshape(-1).repeat(10, 1),\n", ")\n", "best_fit = res[0][np.argmin(res[2].numpy())]\n", @@ -242,7 +246,11 @@ "cell_type": "code", "execution_count": null, "id": "10", - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ "print(best_fit, allparams)\n", @@ -288,7 +296,7 @@ "outputs": [], "source": [ "# Compute jacobian\n", - "J = torch.func.jacfwd(lambda x: sim(x, quad_level=3))(best_fit)\n", + "J = torch.func.jacfwd(lambda x: sim(x))(best_fit)\n", "fig, axarr = plt.subplots(3, 7, figsize=(21, 9))\n", "for i, ax in enumerate(axarr.flatten()):\n", " ax.imshow(J[..., i], origin=\"lower\")\n", @@ -310,7 +318,11 @@ "cell_type": "code", "execution_count": null, "id": "14", - "metadata": {}, + "metadata": { + "tags": [ + "hide-cell" + ] + }, "outputs": [], "source": [ "def corner_plot_covariance(\n", diff --git a/docs/source/examples/Example_ImageFit_NUTS.ipynb b/docs/source/examples/Example_ImageFit_MCMC.ipynb similarity index 55% rename from docs/source/examples/Example_ImageFit_NUTS.ipynb rename to docs/source/examples/Example_ImageFit_MCMC.ipynb index e53bcbdf..3a8a613a 100644 --- a/docs/source/examples/Example_ImageFit_NUTS.ipynb +++ b/docs/source/examples/Example_ImageFit_MCMC.ipynb @@ -5,9 +5,9 @@ "id": "0", "metadata": {}, "source": [ - "# Modelling a Lens image using a No U-Turn Sampler\n", + "# Modelling a Lens image using MCMC\n", "\n", - "In this hypothetical scenario we have an image of galaxy galaxy strong lensing and we would like to recover a model of this scene. Thus we will need to determine parameters for the background source light, the lensing galaxy light, and the lensing galaxy mass distribution. A common technique for analyzing strong lensing systems is a Markov Chain Monte-Carlo which can explore the parameter space and provide us with important metrics about the model and uncertainty on all parameters. Since caustics is differentiable we have access to especially efficient gradient based MCMC algorithms. A very convenient algorithm is the No U-Turn Sampler, or NUTS, which uses derivarives to efficiently explore the likelihood distribution by treating it like a potential that a point mass is exploring. The NUST version we use as implemented in the Pyro package has no tunable parameters, thus we can simply give it a start point and it will explore for as many iterations as we give it. What's more, NUTS is so efficient that very often the autocorrelation length for the samples is approximately 1, meaning that each sample is independent from all the others! This is especially handy in the complex non-linear space of strong lensing models." + "In this hypothetical scenario we have an image of galaxy galaxy strong lensing and we would like to recover a model of this scene. Thus we will need to determine parameters for the background source light, the lensing galaxy light, and the lensing galaxy mass distribution. A common technique for analyzing strong lensing systems is a Markov Chain Monte-Carlo which can explore the parameter space and provide us with important metrics about the model and uncertainty on all parameters. Since caustics is differentiable we have access to especially efficient gradient based MCMC algorithms. A very convenient algorithm is the No U-Turn Sampler, or NUTS, which uses derivatives to efficiently explore the likelihood distribution by treating it like a potential that a point mass is exploring. The NUTS version we use as implemented in the Pyro package has no tunable parameters, thus we can simply give it a start point and it will explore for as many iterations as we give it. What's more, NUTS is so efficient that very often the autocorrelation length for the samples is approximately 1, meaning that each sample is independent from all the others! This is especially handy in the complex non-linear space of strong lensing models." ] }, { @@ -21,13 +21,16 @@ "import numpy as np\n", "import torch\n", "import matplotlib.pyplot as plt\n", + "from matplotlib import colormaps\n", "from matplotlib.patches import Ellipse\n", "from scipy.stats import norm\n", + "from tqdm.notebook import tqdm\n", "\n", "import pyro\n", "import pyro.distributions as dist\n", "from pyro.infer import MCMC as pyro_MCMC\n", - "from pyro.infer import NUTS as pyro_NUTS" + "from pyro.infer import NUTS as pyro_NUTS\n", + "import emcee" ] }, { @@ -61,6 +64,7 @@ "cosmology.to(dtype=torch.float32)\n", "\n", "upsample_factor = 1\n", + "quad_level = 3\n", "thx, thy = caustics.utils.meshgrid(\n", " pixelscale / upsample_factor,\n", " upsample_factor * numPix,\n", @@ -133,16 +137,18 @@ "source": [ "## Sample some mock data\n", "\n", - "Here we write out the true values for all the parameters in the model. In total there are 21 parameters, so this is quite a complex model already! We then plot the data so we can see what it is we re trying to fit.\n", - "\n", - "Note that when we sample the simulator we call it with quad_level=7. This means the simulator will use gaussian quadrature sub-pixel integration to ensure the brightness of each pixel is very accurately computed." + "Here we write out the true values for all the parameters in the model. In total there are 21 parameters, so this is quite a complex model already! We then plot the data so we can see what it is we re trying to fit." ] }, { "cell_type": "code", "execution_count": null, "id": "7", - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ "# Generate the mock data\n", @@ -182,7 +188,7 @@ "print(true_params)\n", "\n", "# simulate lens, crop extra evaluation for PSF\n", - "true_system = sim(allparams, quad_level=7) # simulate at high resolution\n", + "true_system = sim(allparams) # simulate at high resolution\n", "\n", "fig, axarr = plt.subplots(1, 2, figsize=(15, 8))\n", "axarr[0].imshow(\n", @@ -226,7 +232,11 @@ "cell_type": "code", "execution_count": null, "id": "9", - "metadata": {}, + "metadata": { + "tags": [ + "hide-output" + ] + }, "outputs": [], "source": [ "def step(model, prior):\n", @@ -246,7 +256,7 @@ " \"jit_compile\": True,\n", " \"ignore_jit_warnings\": True,\n", " \"step_size\": 1e-3,\n", - " \"full_mass\": True,\n", + " \"full_mass\": False,\n", " \"adapt_step_size\": True,\n", " \"adapt_mass_matrix\": True,\n", "}\n", @@ -282,16 +292,31 @@ "metadata": {}, "outputs": [], "source": [ - "chain = mcmc.get_samples()[\"x\"]\n", - "chain = chain.numpy()\n", + "chain_nuts = mcmc.get_samples()[\"x\"]\n", + "chain_nuts = chain_nuts.numpy()\n", "\n", - "plt.plot(\n", - " range(len(chain)),\n", - " (chain - np.mean(chain, axis=0)) / np.std(chain, axis=0)\n", - " + 5 * np.arange(len(allparams)),\n", - ")\n", + "normed_chains = (chain_nuts - np.mean(chain_nuts, axis=0)) / np.std(chain_nuts, axis=0)\n", + "for i in range(chain_nuts.shape[1]):\n", + " plt.plot(normed_chains[:, i], color=colormaps[\"viridis\"](i / chain_nuts.shape[1]))\n", "plt.title(\"Chain for each parameter\")\n", - "plt.show()" + "plt.show()\n", + "\n", + "print(\n", + " \"Autocorrelation time: \",\n", + " np.mean(\n", + " emcee.autocorr.integrated_time(\n", + " chain_nuts, has_walkers=False, tol=10, quiet=True\n", + " )\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e6ebd65b", + "metadata": {}, + "source": [ + "As is common for NUTS sampling, the average autocorrelation time for the parameters is around 1, meaning that essentially every sample is independent. This is unlike a Metropolis-Hastings sampler like the one used in emcee (see the other example tutorial)." ] }, { @@ -308,7 +333,11 @@ "cell_type": "code", "execution_count": null, "id": "13", - "metadata": {}, + "metadata": { + "tags": [ + "hide-cell" + ] + }, "outputs": [], "source": [ "def corner_plot(\n", @@ -409,20 +438,336 @@ "metadata": {}, "outputs": [], "source": [ - "fig = corner_plot(chain, true_values=allparams.numpy())\n", + "fig = corner_plot(chain_nuts, true_values=allparams.numpy())\n", "plt.show()" ] }, + { + "cell_type": "markdown", + "id": "bd256dff", + "metadata": {}, + "source": [ + "## Fit using emcee\n", + "\n", + "We now model the data using emcee which handles standard Metropolis-Hastings MCMC sampling (plus a few tricks). First we need to construct a log likelihood function. In our case this is just the squared residuals, divided by the variance in each pixel. The rest is specific to the emcee implementation." + ] + }, { "cell_type": "code", "execution_count": null, "id": "16", "metadata": {}, "outputs": [], + "source": [ + "vsim = torch.vmap(sim)\n", + "\n", + "\n", + "def density(x):\n", + " # Log-likelihood function\n", + " model = vsim(torch.as_tensor(x))\n", + " log_likelihood_value = -0.5 * torch.sum(\n", + " ((model - obs_system) ** 2) / variance, dim=(1, 2)\n", + " )\n", + " log_likelihood_value = torch.nan_to_num(log_likelihood_value, nan=-np.inf)\n", + " return log_likelihood_value.numpy()\n", + "\n", + "\n", + "nwalkers = 64\n", + "ndim = len(allparams)\n", + "\n", + "sampler = emcee.EnsembleSampler(nwalkers, ndim, density, vectorize=True)\n", + "\n", + "x0 = allparams + 0.01 * torch.randn(nwalkers, ndim)\n", + "print(\"burn-in\")\n", + "state = sampler.run_mcmc(x0, 100, skip_initial_state_check=True) # burn-in\n", + "sampler.reset()\n", + "print(\"production\")\n", + "state = sampler.run_mcmc(state, 1000) # production" + ] + }, + { + "cell_type": "markdown", + "id": "d9ed9291", + "metadata": {}, + "source": [ + "We have taken 64000 samples in this demo, in general you would want many more (each chain needs to run longer than 1000 steps). Its always a good idea to plot the chains and check that they look uncorrelated. Here we can see the non zero autocorrelation length for one of the chains even over 1000 steps. This indicates we should run the chains much longer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d87abb8", + "metadata": {}, + "outputs": [], + "source": [ + "chain_mh = sampler.get_chain()\n", + "normed_chains = (chain_mh[:, 0] - np.mean(chain_mh[:, 0], axis=0)) / np.std(\n", + " chain_mh[:, 0], axis=0\n", + ")\n", + "for i in range(chain_mh.shape[2]):\n", + " plt.plot(normed_chains[:, i], color=colormaps[\"viridis\"](i / chain_mh.shape[2]))\n", + "plt.title(\"Chain for each parameter\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "480e849f", + "metadata": {}, + "source": [ + "Since the autocorrelation length is >1, we can compute an effective sample size to determine how many equivalent independent points we have drawn. As the warning suggests, in this demo we cannot compute the actual autocorrelation length, the autocorrelation length increases as we draw more samples (you can test this by changing the 1000 above to a larger number). Assuming that the autocorrelation is actually of a similar length to the chain (1000), this means we have drawn approximately 64 independent samples (one for each walker) which is similar to the NUTS example, though for NUTS we knew each sample was independent." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c5b0924f", + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " \"Autocorrelation time: \",\n", + " np.mean(emcee.autocorr.integrated_time(chain_mh, quiet=True)),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a51c577b", + "metadata": {}, + "source": [ + "Similar to the NUTS example, we may plot the samples in a corner plot. However, we thin the samples first so that the number of points is not overwhelming to plot. As you can see in the subfigures there is still a bloby structure of the samples, suggesting that the chains were not run long enough to converge." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "afae21ff", + "metadata": {}, + "outputs": [], + "source": [ + "N = chain_mh.shape[0] * chain_mh.shape[1]\n", + "fig = corner_plot(\n", + " np.concatenate(chain_mh, axis=0)[:: int(N / 200)],\n", + " true_values=allparams.numpy(),\n", + ")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "1a4a2be6", + "metadata": {}, + "source": [ + "# Fit with MALA sampling\n", + "\n", + "Metropolis Adjusted Langevin Algorithm sampling is the half way point between NUTS and MH, it uses gradient information to make an efficient proposal distribution for a MH step. We have written a basic implementation below for demo purposes. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0a6de8f", + "metadata": { + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "def mala_sampler(\n", + " initial_state, log_prob, log_prob_grad, num_samples, epsilon, mass_matrix\n", + "):\n", + " \"\"\"Metropolis Adjusted Langevin Algorithm (MALA) sampler with batch dimension.\n", + "\n", + " Args:\n", + " - initial_state (numpy array): Initial states of the chains, shape (num_chains, dim).\n", + " - log_prob (function): Function to compute the log probabilities of the current states.\n", + " - log_prob_grad (function): Function to compute the gradients of the log probabilities.\n", + " - num_samples (int): Number of samples to generate.\n", + " - epsilon (float): Step size for the Langevin dynamics.\n", + " - mass_matrix (numpy array): Mass matrix, shape (dim, dim), used to scale the dynamics.\n", + "\n", + "\n", + " Returns:\n", + " - samples (numpy array): Array of sampled values, shape (num_samples, num_chains, dim).\n", + " \"\"\"\n", + " num_chains, dim = initial_state.shape\n", + " samples = np.zeros((num_samples, num_chains, dim))\n", + " x_current = np.array(initial_state)\n", + " current_log_prob = log_prob(x_current)\n", + " inv_mass_matrix = np.linalg.inv(mass_matrix)\n", + " chol_inv_mass_matrix = np.linalg.cholesky(inv_mass_matrix)\n", + "\n", + " pbar = tqdm(range(num_samples))\n", + " acceptance_rate = np.zeros([0])\n", + " for i in pbar:\n", + " gradients = log_prob_grad(x_current)\n", + " noise = np.dot(np.random.randn(num_chains, dim), chol_inv_mass_matrix.T)\n", + " proposal = (\n", + " x_current\n", + " + 0.5 * epsilon**2 * np.dot(gradients, inv_mass_matrix)\n", + " + epsilon * noise\n", + " )\n", + "\n", + " # proposal = x_current + 0.5 * epsilon**2 * gradients + epsilon * np.random.randn(num_chains, *dim)\n", + " proposal_log_prob = log_prob(proposal)\n", + " # Metropolis-Hastings acceptance criterion, computed for each chain\n", + " acceptance_log_prob = proposal_log_prob - current_log_prob\n", + " accept = np.log(np.random.rand(num_chains)) < acceptance_log_prob\n", + " acceptance_rate = np.concatenate([acceptance_rate, accept])\n", + " pbar.set_description(f\"Acceptance rate: {acceptance_rate.mean():.2f}\")\n", + "\n", + " # Update states where accepted\n", + " x_current[accept] = proposal[accept]\n", + " current_log_prob[accept] = proposal_log_prob[accept]\n", + "\n", + " samples[i] = x_current\n", + "\n", + " return samples" + ] + }, + { + "cell_type": "markdown", + "id": "e934cbb9", + "metadata": {}, + "source": [ + "Here we run the MALA sampler after a small burn-in. We cheat a little bit and use the previous sampler to construct a mass matrix, this makes MALA more efficient but you could just as easily set the mass matrix to identity for the burn-in then use the burn-in samples to get a mediocre mass matrix, it only requires more fiddling with parameters (epsilon)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52c53b76", + "metadata": {}, + "outputs": [], + "source": [ + "def density_grad(x):\n", + " x = torch.as_tensor(x, dtype=torch.float32)\n", + " x.requires_grad = True\n", + " model = vsim(x)\n", + " log_likelihood_value = -0.5 * torch.sum(\n", + " ((model - obs_system) ** 2) / variance, dim=(1, 2)\n", + " )\n", + " log_likelihood_value = torch.nan_to_num(log_likelihood_value, nan=-np.inf)\n", + " log_likelihood_value.sum().backward()\n", + " return x.grad.numpy()\n", + "\n", + "\n", + "nwalkers = 32\n", + "x0 = allparams + 0.01 * torch.randn(nwalkers, ndim)\n", + "mass_matrix = np.linalg.inv(np.cov(chain_mh.reshape(-1, ndim), rowvar=False))\n", + "\n", + "chain_burnin_mala = mala_sampler(\n", + " initial_state=x0,\n", + " log_prob=density,\n", + " log_prob_grad=density_grad,\n", + " num_samples=100,\n", + " epsilon=3e-1,\n", + " mass_matrix=mass_matrix,\n", + ") # burn-in\n", + "\n", + "chain_mala = mala_sampler(\n", + " initial_state=chain_burnin_mala[-1],\n", + " log_prob=density,\n", + " log_prob_grad=density_grad,\n", + " num_samples=1000,\n", + " epsilon=3e-1,\n", + " mass_matrix=mass_matrix,\n", + ") # production" + ] + }, + { + "cell_type": "markdown", + "id": "e483c657", + "metadata": {}, + "source": [ + "PLotting the chains we see that they mix much better than the MH sampler, but less than NUTS, as would be expected." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fca7a07", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "normed_chains = (chain_mala[:, 0] - np.mean(chain_mala[:, 0], axis=0)) / np.std(\n", + " chain_mala[:, 0], axis=0\n", + ")\n", + "for i in range(chain_mala.shape[2]):\n", + " plt.plot(normed_chains[:, i], color=colormaps[\"viridis\"](i / chain_mala.shape[2]))\n", + "plt.title(\"Chain for each parameter\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "57c8399f", + "metadata": {}, + "source": [ + "The autocorrelation length is better than MH, but worse than NUTS as expected. Again the effective sample size can't be trusted and is probably comparable to our NUTS example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d058064a", + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " \"Autocorrelation time: \",\n", + " np.mean(emcee.autocorr.integrated_time(chain_mala, quiet=True)),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "30c45f4f", + "metadata": {}, + "source": [ + "The corner plot much like the MH example shows that more samples are needed given the visible blobs and gaps in the sampling." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e97dead", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "N = chain_mh.shape[0] * chain_mh.shape[1]\n", + "fig = corner_plot(\n", + " np.concatenate(chain_mh, axis=0)[:: int(N / 200)],\n", + " true_values=allparams.numpy(),\n", + ")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f84bcb90", + "metadata": {}, + "outputs": [], "source": [] } ], "metadata": { + "kernelspec": { + "display_name": "PY39", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -432,7 +777,8 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" + "pygments_lexer": "ipython3", + "version": "3.9.5" } }, "nbformat": 4, diff --git a/docs/source/examples/Example_QSOLensFit.ipynb b/docs/source/examples/Example_QSOLensFit.ipynb index dff248e0..c9febffa 100644 --- a/docs/source/examples/Example_QSOLensFit.ipynb +++ b/docs/source/examples/Example_QSOLensFit.ipynb @@ -106,11 +106,22 @@ " dtype=torch.float32,\n", ")\n", "\n", - "fig, ax = plt.subplots()\n", - "\n", "A = lens.jacobian_lens_equation(thx, thy, z_s, params)\n", - "detA = torch.linalg.det(A)\n", - "\n", + "detA = torch.linalg.det(A)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2263184", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots()\n", "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", "# Get the path from the matplotlib contour plot of the critical line\n", "paths = CS.allsegs[0]\n", @@ -190,7 +201,11 @@ "cell_type": "code", "execution_count": null, "id": "11", - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ "fig, ax = plt.subplots()\n", diff --git a/docs/source/interfaceindex.rst b/docs/source/interfaceindex.rst index 64c1eb58..10551691 100644 --- a/docs/source/interfaceindex.rst +++ b/docs/source/interfaceindex.rst @@ -5,6 +5,6 @@ There are three user interfaces to Caustics: configuration file, object-oriented Here we have built the same tutorial three times over so you can see how the three interfaces compare. They provide progressively more freedom and power, but also require more knowledge of the underlying system. -1. Configuration File: :doc:`tutorials/BasicIntroduction_yaml` -2. Object-Oriented: :doc:`tutorials/BasicIntroduction_oop` -3. Functional: :doc:`tutorials/BasicIntroduction_func` +1. Configuration File: :doc:`tutorials/InterfaceIntroduction_yaml` +2. Object-Oriented: :doc:`tutorials/InterfaceIntroduction_oop` +3. Functional: :doc:`tutorials/InterfaceIntroduction_func` diff --git a/docs/source/intro.md b/docs/source/intro.md index 2a438810..33cff80d 100644 --- a/docs/source/intro.md +++ b/docs/source/intro.md @@ -1,5 +1,7 @@ # Welcome to Caustics’ documentation! +![Logo GIF](../../media/caustics_logo.gif) + The lensing pipeline of the future: GPU-accelerated, automatically-differentiable, highly modular and extensible. @@ -39,10 +41,10 @@ x = torch.tensor([ 5.0, -0.2, 0.0, 0.8, 0.0, 1., 1.0, 10.0 ]) # fmt: skip -minisim = caustics.LensSource( - lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100 +sim = caustics.LensSource( + lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100, quad_level=3 ) -plt.imshow(minisim(x, quad_level=3), origin="lower") +plt.imshow(sim(x), origin="lower") plt.show() ``` @@ -54,7 +56,7 @@ plt.show() newx = x.repeat(20, 1) newx += torch.normal(mean=0, std=0.1 * torch.ones_like(newx)) -images = torch.vmap(minisim)(newx) +images = torch.vmap(sim)(newx) fig, axarr = plt.subplots(4, 5, figsize=(20, 16)) for ax, im in zip(axarr.flatten(), images): @@ -67,7 +69,7 @@ plt.show() ### Automatic Differentiation ```python -J = torch.func.jacfwd(minisim)(x) +J = torch.func.jacfwd(sim)(x) # Plot the new images fig, axarr = plt.subplots(3, 7, figsize=(20, 9)) diff --git a/docs/source/tutorials/InterfaceIntroduction_func.ipynb b/docs/source/tutorials/InterfaceIntroduction_func.ipynb index 8c90c395..ffae44c3 100644 --- a/docs/source/tutorials/InterfaceIntroduction_func.ipynb +++ b/docs/source/tutorials/InterfaceIntroduction_func.ipynb @@ -196,8 +196,19 @@ "metadata": {}, "outputs": [], "source": [ - "J = torch.func.jacfwd(sim)(x) # Substitute minisim with sim for the yaml method\n", - "\n", + "J = torch.func.jacfwd(sim)(x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(3, 7, figsize=(20, 9))\n", "labels = [\n", " \"z_s\",\n", diff --git a/docs/source/tutorials/InterfaceIntroduction_oop.ipynb b/docs/source/tutorials/InterfaceIntroduction_oop.ipynb index aaca80f7..dbb251a4 100644 --- a/docs/source/tutorials/InterfaceIntroduction_oop.ipynb +++ b/docs/source/tutorials/InterfaceIntroduction_oop.ipynb @@ -129,7 +129,7 @@ "outputs": [], "source": [ "sim = caustics.LensSource(\n", - " lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100\n", + " lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100, quad_level=3\n", ")" ] }, @@ -174,7 +174,7 @@ "outputs": [], "source": [ "# Substitute minisim with sim for yaml method\n", - "image = sim(x, quad_level=3).detach().cpu().numpy()\n", + "image = sim(x).detach().cpu().numpy()\n", "\n", "plt.imshow(image, origin=\"lower\")\n", "plt.axis(\"off\")\n", @@ -226,8 +226,19 @@ "source": [ "# Now lets compute the jacobian of the simulator wrt each parameter\n", "J = torch.func.jacfwd(sim)(x)\n", - "# The shape of J is (npixels y, npixels x, nparameters)\n", - "\n", + "# The shape of J is (npixels y, npixels x, nparameters)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(3, 7, figsize=(20, 9))\n", "labels = tuple(sim.state_dict().keys())[3:]\n", "for i, ax in enumerate(axarr.flatten()):\n", diff --git a/docs/source/tutorials/InterfaceIntroduction_yaml.ipynb b/docs/source/tutorials/InterfaceIntroduction_yaml.ipynb index 6eef1777..d262ea12 100644 --- a/docs/source/tutorials/InterfaceIntroduction_yaml.ipynb +++ b/docs/source/tutorials/InterfaceIntroduction_yaml.ipynb @@ -100,7 +100,7 @@ "source": [ "### Plot the Results!\n", "\n", - "This section is mostly self explanatory. We evaluate the simulator configuration by calling it like a function. There are several possible arguments to the simulator function, here we use `quad_level` which indicates the depth of quadrature integration to use for each pixel (how accurate to make the flux)." + "This section is mostly self explanatory. We evaluate the simulator configuration by calling it like a function. " ] }, { @@ -110,7 +110,7 @@ "outputs": [], "source": [ "# Here we sample an image\n", - "image = sim(x, quad_level=3).detach().cpu().numpy()\n", + "image = sim(x).detach().cpu().numpy()\n", "\n", "plt.imshow(image, origin=\"lower\")\n", "plt.axis(\"off\")\n", @@ -163,8 +163,19 @@ "source": [ "# Now lets compute the jacobian of the simulator wrt each parameter\n", "J = torch.func.jacfwd(sim)(x)\n", - "# The shape of J is (npixels y, npixels x, nparameters)\n", - "\n", + "# The shape of J is (npixels y, npixels x, nparameters)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(3, 7, figsize=(20, 9))\n", "labels = tuple(sim.state_dict().keys())[3:]\n", "for i, ax in enumerate(axarr.flatten()):\n", diff --git a/docs/source/tutorials/InvertLensEquation.ipynb b/docs/source/tutorials/InvertLensEquation.ipynb index 77720582..edd6dca9 100644 --- a/docs/source/tutorials/InvertLensEquation.ipynb +++ b/docs/source/tutorials/InvertLensEquation.ipynb @@ -7,7 +7,7 @@ "source": [ "# Inverting the Lens Equation\n", "\n", - "The lens equation $\\vec{\\beta} = \\vec{\\theta} - \\vec{\\alpha}(\\vec{\\theta})$ allows us to find a point in the source plane given a point in the image plane. However, sometimes we know a point in the source plane and would like to see where it ends up in the image plane. This is not easy to do since a point in the source plane may map to multiple locations in the image plane. There is no closed form function to invert the lens equation, in large part because the deflection angle $\\vec{\\alpha}$ depends on the position in the image plane $\\vec{\\theta}$. To invert the lens equation, we will need to rely on optimization and a little luck to find all the images for a given source plane point. Below we will demonstrate how this is done in caustic!" + "The lens equation $\\vec{\\beta} = \\vec{\\theta} - \\vec{\\alpha}(\\vec{\\theta})$ allows us to find a point in the source plane given a point in the image plane. However, sometimes we know a point in the source plane and would like to see where it ends up in the image plane. This is not easy to do since a point in the source plane may map to multiple locations in the image plane. There is no closed form function to invert the lens equation, in large part because the deflection angle $\\vec{\\alpha}$ depends on the position in the image plane $\\vec{\\theta}$. To invert the lens equation, we will need to rely on optimization and a iterative procedures to find all the images for a given source plane point. Below we will demonstrate how this is done in caustic!" ] }, { @@ -23,6 +23,8 @@ "\n", "import torch\n", "import matplotlib.pyplot as plt\n", + "from matplotlib.patches import Polygon\n", + "from matplotlib.collections import PatchCollection\n", "import numpy as np\n", "\n", "import caustics" @@ -64,6 +66,14 @@ ")" ] }, + { + "cell_type": "markdown", + "id": "2163ed78", + "metadata": {}, + "source": [ + "Here we run the forward raytracing for our particular lens model. In caustics we provide a convenient `forward_raytrace` function which can be called for any lens model. Internally, this constructs a number of triangles in the image plane, raytraces them to the source plane and identifies which ones contain the desired source plane position. Iteratively subdividing the triangles eventually converges on image plane positions which map to the desired source plane position. See further down for more detail." + ] + }, { "cell_type": "code", "execution_count": null, @@ -82,11 +92,23 @@ "bx, by = lens.raytrace(x, y, z_s)" ] }, + { + "cell_type": "markdown", + "id": "462b2e8f", + "metadata": {}, + "source": [ + "When we raytrace the coordinates we get out from `forward_raytrace` it is not too surprising that they all give source plane positions very close to the desired source plane position. Here we plot them so you can see:" + ] + }, { "cell_type": "code", "execution_count": null, "id": "4", - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ "fig, ax = plt.subplots()\n", @@ -110,20 +132,581 @@ "ax.scatter(x, y, color=\"b\", label=\"forward raytrace\", zorder=10)\n", "ax.scatter(bx, by, color=\"r\", marker=\"x\", label=\"source plane\", zorder=9)\n", "ax.scatter([sp_x.item()], [sp_y.item()], color=\"g\", label=\"true pos\", zorder=8)\n", + "ax.set_axis_off()\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "6641535a", + "metadata": {}, + "source": [ + "It is also often not necessary to model the central demagnified region since it is so faint (approximately a 100,000 times fainter in this case) that it doesn't contribute measurably to the flux of an image. We can very easily check the magnification of every point and remove the unnecessary one." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5d20df7", + "metadata": {}, + "outputs": [], + "source": [ + "m = lens.magnification(x, y, z_s)\n", + "print(m.detach().cpu().tolist())\n", + "N_m = torch.argsort(m)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "# Get the path from the matplotlib contour plot of the critical line\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "\n", + "plt.scatter(x[N_m[1:]], y[N_m[1:]], color=\"b\", label=\"magnified\")\n", + "plt.scatter(x[N_m[0]], y[N_m[0]], color=\"r\", label=\"de-magnified\")\n", + "plt.axis(\"off\")\n", "plt.legend()\n", "plt.show()" ] }, + { + "cell_type": "markdown", + "id": "0163d5f7", + "metadata": {}, + "source": [ + "## Lets take a look\n", + "\n", + "Using the `LensSource` simulator and the forward raytracing coordinates we can focus our calculations on the regions of interest for each image. Note however that the regions can overlap, which they do very slightly in this case." + ] + }, { "cell_type": "code", "execution_count": null, "id": "5", "metadata": {}, "outputs": [], + "source": [ + "src = caustics.Sersic(\n", + " x0=0.2, y0=0.2, q=0.9, phi=0.0, n=1.0, Re=0.05, Ie=1.0, name=\"source\"\n", + ")\n", + "\n", + "sim = caustics.LensSource(\n", + " lens=lens, source=src, z_s=z_s, x0=None, y0=None, pixelscale=0.005, pixels_x=100\n", + ")\n", + "\n", + "# Plot the source and lens\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "# Get the path from the matplotlib contour plot of the critical line\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "for i in range(len(x)):\n", + " ax.imshow(\n", + " sim([x[i], y[i]]),\n", + " extent=(\n", + " -sim.pixelscale * sim.pixels_x / 2 + x[i],\n", + " sim.pixelscale * sim.pixels_x / 2 + x[i],\n", + " -sim.pixelscale * sim.pixels_y / 2 + y[i],\n", + " sim.pixelscale * sim.pixels_y / 2 + y[i],\n", + " ),\n", + " origin=\"lower\",\n", + " )\n", + "ax.set_xlim([-1.5, 2])\n", + "ax.set_ylim([-1.5, 2])\n", + "ax.set_axis_off()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "edb5fa5e", + "metadata": {}, + "source": [ + "This is much more efficient than evaluating a whole image. Below you can see the same setup but we see how the simulator spends a lot of pixels evaluating low flux areas that don't matter much for modelling." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0689864", + "metadata": {}, + "outputs": [], + "source": [ + "sim_wide = caustics.LensSource(\n", + " lens=lens, source=src, z_s=z_s, pixelscale=0.005, pixels_x=1000\n", + ")\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "ax.imshow(\n", + " sim_wide({}),\n", + " origin=\"lower\",\n", + " extent=(\n", + " -sim_wide.pixelscale * sim_wide.pixels_x / 2,\n", + " sim_wide.pixelscale * sim_wide.pixels_x / 2,\n", + " -sim_wide.pixelscale * sim_wide.pixels_y / 2,\n", + " sim_wide.pixelscale * sim_wide.pixels_y / 2,\n", + " ),\n", + ")\n", + "ax.set_xlim([-1.5, 2])\n", + "ax.set_ylim([-1.5, 2])\n", + "ax.set_axis_off()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "7a4c8fd5", + "metadata": {}, + "source": [ + "## How forward_raytrace works\n", + "\n", + "All forward raytracing methods are imperfect as they involve iterative solutions which require enough resolution to pick out all the relevant image plane positions. To start, lets consider a more naive algorithm, simply placing random points in the image plane, then running a root-finding algorithm to get the source plane positions to line up." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e06ac42", + "metadata": {}, + "outputs": [], + "source": [ + "Ninit = 100\n", + "x_init = (torch.rand(Ninit) - 0.5) * fov\n", + "y_init = (torch.rand(Ninit) - 0.5) * fov\n", + "\n", + "\n", + "def raytrace(x, y):\n", + " return lens.raytrace(x, y, z_s)\n", + "\n", + "\n", + "final = caustics.lenses.func.forward_raytrace_rootfind(\n", + " x_init, y_init, sp_x, sp_y, raytrace\n", + ")\n", + "x_final, y_final = final[..., 0], final[..., 1]\n", + "\n", + "# Pick only points that converged\n", + "bx_final, by_final = raytrace(x_final, y_final)\n", + "R = torch.sqrt((sp_x - bx_final) ** 2 + (sp_y - by_final) ** 2)\n", + "x_final = x_final[R < 1e-3]\n", + "y_final = y_final[R < 1e-3]" + ] + }, + { + "cell_type": "markdown", + "id": "2abb217e", + "metadata": {}, + "source": [ + "Here we easily find the four magnified images, but the central demagnified image is (often) not found by this method since a point has to get lucky enough to start very close to the correct position in order for the gradient based root finder to work." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2fe0e6f", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "colors = [\"tab:red\", \"tab:blue\", \"tab:green\", \"tab:orange\", \"tab:purple\"]\n", + "for c in colors:\n", + " if x_final.shape[0] == 0:\n", + " break\n", + " R = ((x_final[0] - x_final) ** 2 + (y_final[0] - y_final) ** 2).sqrt()\n", + " ax.scatter(x_init[R < 0.1], y_init[R < 0.1], color=c)\n", + " ax.scatter(x_final[0], y_final[0], color=\"k\", s=200, marker=\"*\")\n", + " ax.scatter(x_final[0], y_final[0], color=c, s=100, marker=\"*\")\n", + " x_init = x_init[R >= 0.1]\n", + " y_init = y_init[R >= 0.1]\n", + " x_final = x_final[R >= 0.1]\n", + " y_final = y_final[R >= 0.1]\n", + "ax.axes.set_axis_off()\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "b4c5d47b", + "metadata": {}, + "source": [ + "Let's now look at a more clever algorithm. We will map triangles in the image plane to triangles in the source plane, we may then explore recursively, any triangles which enclose the desired source point. Due to the non-linearity of the gravitational lensing transformation, we will also search the neighbor of any triangle that seems to have found an image position. First we highlight in green, any triangles which contain the source point, then expand to all their neighbors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5677ef22", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "n = 10\n", + "s = torch.stack((sp_x, sp_y))\n", + "# Construct a tiling of the image plane (squares at this point)\n", + "X, Y = torch.meshgrid(\n", + " torch.linspace(-fov / 2, fov / 2, n),\n", + " torch.linspace(-fov / 2, fov / 2, n),\n", + " indexing=\"ij\",\n", + ")\n", + "E1 = torch.stack((X, Y), dim=-1)\n", + "# build the upper and lower triangles within the squares of the grid\n", + "E1 = torch.cat(\n", + " (\n", + " torch.stack((E1[:-1, :-1], E1[:-1, 1:], E1[1:, 1:]), dim=-2),\n", + " torch.stack((E1[:-1, :-1], E1[1:, :-1], E1[1:, 1:]), dim=-2),\n", + " ),\n", + " dim=0,\n", + ").reshape(-1, 3, 2)\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E1[..., 0], E1[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate1 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E1, locate1):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=1,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()\n", + "\n", + "# Get all the neighbors and upsample the triangles\n", + "E2 = E1[locate1]\n", + "E2 = torch.vmap(caustics.lenses.func.triangle_neighbors)(E2)\n", + "E2 = E2.reshape(-1, 3, 2)\n", + "E2 = caustics.lenses.func.remove_triangle_duplicates(E2)\n", + "# Upsample the triangles\n", + "E2 = torch.vmap(caustics.lenses.func.triangle_upsample)(E2)\n", + "E2 = E2.reshape(-1, 3, 2)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E2[..., 0], E2[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate2 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E2, locate2):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=0.5,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "61fdd482", + "metadata": {}, + "source": [ + "The process repeats until the triangles have converged to a very small area, at which point we then run a root finding algorithm to get the final points. The central region is a very unstable optimum, so we need to use the triangle method for several iterations before we can run the root finder to get the exact optimal point." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ef54c41", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "# Get all the neighbors and upsample the triangles\n", + "E3 = E2[locate2]\n", + "E3 = torch.vmap(caustics.lenses.func.triangle_neighbors)(E3)\n", + "E3 = E3.reshape(-1, 3, 2)\n", + "E3 = caustics.lenses.func.remove_triangle_duplicates(E3)\n", + "# Upsample the triangles\n", + "E3 = torch.vmap(caustics.lenses.func.triangle_upsample)(E3)\n", + "E3 = E3.reshape(-1, 3, 2)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E3[..., 0], E3[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate3 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E3, locate3):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=0.5,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()\n", + "\n", + "# Get all the neighbors and upsample the triangles\n", + "E4 = E3[locate3]\n", + "E4 = torch.vmap(caustics.lenses.func.triangle_neighbors)(E4)\n", + "E4 = E4.reshape(-1, 3, 2)\n", + "E4 = caustics.lenses.func.remove_triangle_duplicates(E4)\n", + "# Upsample the triangles\n", + "E4 = torch.vmap(caustics.lenses.func.triangle_upsample)(E4)\n", + "E4 = E4.reshape(-1, 3, 2)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E4[..., 0], E4[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate4 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E4, locate4):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=0.5,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()\n", + "\n", + "# Get all the neighbors and upsample the triangles\n", + "E5 = E4[locate4]\n", + "E5 = torch.vmap(caustics.lenses.func.triangle_neighbors)(E5)\n", + "E5 = E5.reshape(-1, 3, 2)\n", + "E5 = caustics.lenses.func.remove_triangle_duplicates(E5)\n", + "# Upsample the triangles\n", + "E5 = torch.vmap(caustics.lenses.func.triangle_upsample)(E5)\n", + "E5 = E5.reshape(-1, 3, 2)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E5[..., 0], E5[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate5 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E5, locate5):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=0.5,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()\n", + "\n", + "# Get all the neighbors and upsample the triangles\n", + "E6 = E5[locate5]\n", + "E6 = torch.vmap(caustics.lenses.func.triangle_neighbors)(E6)\n", + "E6 = E6.reshape(-1, 3, 2)\n", + "E6 = caustics.lenses.func.remove_triangle_duplicates(E6)\n", + "# Upsample the triangles\n", + "E6 = torch.vmap(caustics.lenses.func.triangle_upsample)(E6)\n", + "E6 = E6.reshape(-1, 3, 2)\n", + "\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "S = raytrace(E6[..., 0], E6[..., 1])\n", + "S = torch.stack(S, dim=-1)\n", + "\n", + "# Identify triangles that contain the source plane point\n", + "locate6 = torch.vmap(caustics.lenses.func.triangle_contains, in_dims=(0, None))(S, s)\n", + "patches = []\n", + "for e, loc in zip(E6, locate6):\n", + " patches.append(\n", + " Polygon(\n", + " e,\n", + " fill=loc,\n", + " alpha=0.4 if loc else 1,\n", + " color=\"tab:green\" if loc else \"k\",\n", + " linewidth=0.5,\n", + " )\n", + " )\n", + "p = PatchCollection(patches, match_original=True)\n", + "ax.add_collection(p)\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()\n", + "\n", + "\n", + "# Run the root finding algorithm\n", + "E7 = E6[locate6].sum(dim=1) / 3\n", + "E7 = caustics.lenses.func.forward_raytrace_rootfind(\n", + " E7[..., 0], E7[..., 1], s[0], s[1], raytrace\n", + ")\n", + "fig, ax = plt.subplots()\n", + "CS = ax.contour(thx, thy, detA, levels=[0.0], colors=\"b\", zorder=1)\n", + "for path in paths:\n", + " # Collect the path into a discrete set of points\n", + " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", + " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", + " # raytrace the points to the source plane\n", + " y1, y2 = lens.raytrace(x1, x2, z_s)\n", + "\n", + " # Plot the caustic\n", + " ax.plot(y1, y2, color=\"r\", zorder=1)\n", + "ax.scatter(E7[..., 0], E7[..., 1], color=\"k\", s=100, marker=\"*\")\n", + "ax.scatter(E7[..., 0], E7[..., 1], color=\"tab:green\", s=50, marker=\"*\")\n", + "ax.set_xlim([-fov / 1.9, fov / 1.9])\n", + "ax.set_ylim([-fov / 1.9, fov / 1.9])\n", + "ax.set_axis_off()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e421807b", + "metadata": {}, + "outputs": [], "source": [] } ], "metadata": { + "kernelspec": { + "display_name": "PY39", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -133,7 +716,8 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" + "pygments_lexer": "ipython3", + "version": "3.9.5" } }, "nbformat": 4, diff --git a/docs/source/tutorials/LensZoo.ipynb b/docs/source/tutorials/LensZoo.ipynb index 53e5fcd7..51d9cdc5 100644 --- a/docs/source/tutorials/LensZoo.ipynb +++ b/docs/source/tutorials/LensZoo.ipynb @@ -94,7 +94,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bc5596c8", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "# axarr[0].imshow(np.log10(convergence.numpy()), origin = \"lower\")\n", "axarr[0].axis(\"off\")\n", @@ -135,7 +148,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c06143aa", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", @@ -181,7 +207,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87b924bb", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", @@ -228,7 +267,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56818b62", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", @@ -275,7 +327,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "102eac0c", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", @@ -325,7 +390,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "984b8ece", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", @@ -375,7 +453,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfe89486", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", @@ -420,7 +511,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6964846c", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "# convergence = avg_pool2d(lens.convergence(thx, thy, z_s, z_l).squeeze()[None, None], upsample_factor).squeeze()\n", "# axarr[0].imshow(np.log10(convergence.numpy()), origin = \"lower\")\n", @@ -462,7 +566,20 @@ " pixels_x=n_pix,\n", " z_s=z_s,\n", " upsample_factor=2,\n", - ")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", @@ -475,14 +592,6 @@ "axarr[1].set_title(\"Lensed Sersic\")\n", "plt.show()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "21", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/docs/source/tutorials/MultiplaneDemo.ipynb b/docs/source/tutorials/MultiplaneDemo.ipynb index 6b7dff66..4f3d4296 100644 --- a/docs/source/tutorials/MultiplaneDemo.ipynb +++ b/docs/source/tutorials/MultiplaneDemo.ipynb @@ -100,14 +100,25 @@ { "cell_type": "code", "execution_count": null, - "id": "5", + "id": "14291ce3", "metadata": {}, "outputs": [], "source": [ "# Effective reduced deflection angles for the multiplane lens system\n", - "ax, ay = lens.effective_reduced_deflection_angle(thx, thy, z_s)\n", - "\n", - "# Plot\n", + "ax, ay = lens.effective_reduced_deflection_angle(thx, thy, z_s)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(12, 5))\n", "im = axarr[0].imshow(\n", " ax, extent=(thx[0][0], thx[0][-1], thy[0][0], thy[-1][0]), origin=\"lower\"\n", @@ -141,8 +152,20 @@ "A = lens.jacobian_lens_equation(thx, thy, z_s)\n", "\n", "# Here we compute detA at every point\n", - "detA = torch.linalg.det(A)\n", - "\n", + "detA = torch.linalg.det(A)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57bb5d70", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ "# Plot the critical line\n", "im = plt.imshow(\n", " np.log10(np.abs(detA.detach().cpu().numpy())),\n", @@ -159,7 +182,11 @@ "cell_type": "code", "execution_count": null, "id": "8", - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ "# For completeness, here are the caustics!\n", diff --git a/docs/source/tutorials/Playground.ipynb b/docs/source/tutorials/Playground.ipynb deleted file mode 100644 index 482ef4f2..00000000 --- a/docs/source/tutorials/Playground.ipynb +++ /dev/null @@ -1,230 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "0", - "metadata": {}, - "source": [ - "# Lensing playground\n", - "\n", - "This is just a fun notebook where you can interactively change lensing parameters. It is a great way to build some intuition around lensing and make some cool pictures!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import torch\n", - "from torch.nn.functional import avg_pool2d\n", - "import matplotlib.pyplot as plt\n", - "from ipywidgets import interact\n", - "import numpy as np\n", - "\n", - "import caustics" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2", - "metadata": {}, - "outputs": [], - "source": [ - "n_pix = 100\n", - "res = 0.05\n", - "upsample_factor = 2\n", - "fov = res * n_pix\n", - "thx, thy = caustics.utils.meshgrid(\n", - " res / upsample_factor,\n", - " upsample_factor * n_pix,\n", - " dtype=torch.float32,\n", - ")\n", - "z_l = torch.tensor(0.5, dtype=torch.float32)\n", - "z_s = torch.tensor(1.5, dtype=torch.float32)\n", - "cosmology = caustics.FlatLambdaCDM(name=\"cosmo\")\n", - "cosmology.to(dtype=torch.float32)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3", - "metadata": {}, - "outputs": [], - "source": [ - "# SIE lens model, kappa map, alpha map, magnification, time delay, caustics\n", - "\n", - "\n", - "def plot_lens_metrics(thx0, thy0, q, phi, b):\n", - " lens = caustics.SIE(\n", - " cosmology=cosmology,\n", - " z_l=z_l,\n", - " x0=thx0,\n", - " y0=thy0,\n", - " q=q,\n", - " phi=phi,\n", - " b=b,\n", - " )\n", - " fig, axarr = plt.subplots(2, 3, figsize=(9, 6))\n", - " kappa = avg_pool2d(\n", - " lens.convergence(thx, thy, z_s)[None, None, :, :], upsample_factor\n", - " )[0, 0]\n", - " axarr[0][0].imshow(torch.log10(kappa), origin=\"lower\")\n", - " axarr[0][0].set_title(\"log(convergence)\")\n", - " psi = avg_pool2d(lens.potential(thx, thy, z_s)[None, None, :, :], upsample_factor)[\n", - " 0, 0\n", - " ]\n", - " axarr[0][1].imshow(psi, origin=\"lower\")\n", - " axarr[0][1].set_title(\"potential\")\n", - " timedelay = avg_pool2d(\n", - " lens.time_delay(thx, thy, z_s)[None, None, :, :], upsample_factor\n", - " )[0, 0]\n", - " axarr[0][2].imshow(timedelay, origin=\"lower\")\n", - " axarr[0][2].set_title(\"time delay\")\n", - " magnification = avg_pool2d(\n", - " lens.magnification(thx, thy, z_s)[None, None, :, :], upsample_factor\n", - " )[0, 0]\n", - " axarr[1][0].imshow(torch.log10(magnification), origin=\"lower\")\n", - " axarr[1][0].set_title(\"log(magnification)\")\n", - " alpha = lens.reduced_deflection_angle(thx, thy, z_s)\n", - " alpha0 = avg_pool2d(alpha[0][None, None, :, :], upsample_factor)[0, 0]\n", - " alpha1 = avg_pool2d(alpha[1][None, None, :, :], upsample_factor)[0, 0]\n", - " axarr[1][1].imshow(alpha0, origin=\"lower\")\n", - " axarr[1][1].set_title(\"deflection angle x\")\n", - " axarr[1][2].imshow(alpha1, origin=\"lower\")\n", - " axarr[1][2].set_title(\"deflection angle y\")\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4", - "metadata": {}, - "outputs": [], - "source": [ - "p = interact(\n", - " plot_lens_metrics,\n", - " thx0=(-2.5, 2.5, 0.1),\n", - " thy0=(-2.5, 2.5, 0.1),\n", - " q=(0.01, 0.99, 0.01),\n", - " phi=(0.0, np.pi, np.pi / 25),\n", - " b=(0.1, 2.0, 0.1),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5", - "metadata": {}, - "outputs": [], - "source": [ - "# Sersic source, demo lensed source\n", - "def plot_lens_distortion(\n", - " x0_lens,\n", - " y0_lens,\n", - " q_lens,\n", - " phi_lens,\n", - " b_lens,\n", - " x0_src,\n", - " y0_src,\n", - " q_src,\n", - " phi_src,\n", - " n_src,\n", - " Re_src,\n", - " Ie_src,\n", - "):\n", - " lens = caustics.SIE(\n", - " cosmology,\n", - " z_l,\n", - " x0=x0_lens,\n", - " y0=y0_lens,\n", - " q=q_lens,\n", - " phi=phi_lens,\n", - " b=b_lens,\n", - " )\n", - " source = caustics.Sersic(\n", - " x0=x0_src,\n", - " y0=y0_src,\n", - " q=q_src,\n", - " phi=phi_src,\n", - " n=n_src,\n", - " Re=Re_src,\n", - " Ie=Ie_src,\n", - " )\n", - " fig, axarr = plt.subplots(1, 3, figsize=(18, 6))\n", - " brightness = avg_pool2d(\n", - " source.brightness(thx, thy)[None, None, :, :], upsample_factor\n", - " )[0, 0]\n", - " axarr[0].imshow(brightness, origin=\"lower\")\n", - " axarr[0].set_title(\"Sersic source\")\n", - " kappa = avg_pool2d(\n", - " lens.convergence(thx, thy, z_s)[None, None, :, :], upsample_factor\n", - " )[0, 0]\n", - " axarr[1].imshow(torch.log10(kappa), origin=\"lower\")\n", - " axarr[1].set_title(\"lens log(convergence)\")\n", - " beta_x, beta_y = lens.raytrace(thx, thy, z_s)\n", - " mu = avg_pool2d(\n", - " source.brightness(beta_x, beta_y)[None, None, :, :], upsample_factor\n", - " )[0, 0]\n", - " axarr[2].imshow(mu, origin=\"lower\")\n", - " axarr[2].set_title(\"Sersic lensed\")\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6", - "metadata": {}, - "outputs": [], - "source": [ - "p = interact(\n", - " plot_lens_distortion,\n", - " x0_lens=(-2.5, 2.5, 0.1),\n", - " y0_lens=(-2.5, 2.5, 0.1),\n", - " q_lens=(0.01, 0.99, 0.01),\n", - " phi_lens=(0.0, np.pi, np.pi / 25),\n", - " b_lens=(0.1, 2.0, 0.1),\n", - " x0_src=(-2.5, 2.5, 0.1),\n", - " y0_src=(-2.5, 2.5, 0.1),\n", - " q_src=(0.01, 0.99, 0.01),\n", - " phi_src=(0.0, np.pi, np.pi / 25),\n", - " n_src=(0.5, 4, 0.1),\n", - " Re_src=(0.1, 2, 0.1),\n", - " Ie_src=(0.1, 2.0, 0.1),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/source/tutorials/VisualizeCaustics.ipynb b/docs/source/tutorials/VisualizeCaustics.ipynb index ec758794..d487a2fe 100644 --- a/docs/source/tutorials/VisualizeCaustics.ipynb +++ b/docs/source/tutorials/VisualizeCaustics.ipynb @@ -113,7 +113,11 @@ "cell_type": "code", "execution_count": null, "id": "6", - "metadata": {}, + "metadata": { + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ "# Get the path from the matplotlib contour plot of the critical line\n", diff --git a/docs/source/tutorials/example.yml b/docs/source/tutorials/example.yml index c9a457f9..c161d782 100644 --- a/docs/source/tutorials/example.yml +++ b/docs/source/tutorials/example.yml @@ -26,3 +26,4 @@ simulator: lens_light: *lnslt pixelscale: 0.05 pixels_x: 100 + quad_level: 3 diff --git a/docs/source/websim.md b/docs/source/websim.md new file mode 100644 index 00000000..377f8a4d --- /dev/null +++ b/docs/source/websim.md @@ -0,0 +1,28 @@ +# Live Caustics Playground + +Here is a little online simulator run using caustics! It's slow because it's +running on a free server, but it's a good way to play with the simulator without +having to install anything. + +Pro tip: check out the "Pixelated" source to lens any image you want! + +