diff --git a/deeplens/diffraclens.py b/deeplens/diffraclens.py index 42a6b40..b48389c 100644 --- a/deeplens/diffraclens.py +++ b/deeplens/diffraclens.py @@ -1,68 +1,96 @@ -"""Pure diffractive lens with all surface represented for wave optics. - -1. Thinlens + DOE + Sensor -2. Real lens + DOE + Sensor -""" +"""Refractive-diffractive lens with all surface represented using paraxial wave optics.""" +import json import matplotlib.pyplot as plt -import numpy as np import torch import torch.nn.functional as F -from torchvision.utils import make_grid, save_image - -from .optics.basics import ( - BLUE_RESPONSE, - DEPTH, - DEVICE, - GREEN_RESPONSE, - RED_RESPONSE, - WAVE_BOARD_BAND, - WAVE_RGB, - DeepObj, -) -from .optics.surfaces_diffractive import DOE, Sensor, ThinLens +from torchvision.utils import save_image + +from .optics.basics import DEPTH +from .optics.surfaces_diffractive import DOE, Sensor, ThinLens, Aperture from .optics.waveoptics_utils import point_source_field +from .lens import Lens -class DoeThinLens(DeepObj): - """DOE with thin lens model. +class DoeThinLens(Lens): + """Paraxial refractive-diffractive lens model. The lens consists of a thin lens, a diffractive optical element (DOE), an aperture, and a sensor. All optical surfaces are modeled using paraxial wave optics. - Reference: https://www.lighttrans.com/use-cases/application/chromatic-aberration-correction.html + This paraxial model is used in most existing computational imaging papers. It is a good approximation for the lens with a small field of view and a small aperture. The model can optimize the DOE but not the thin lens. """ - def __init__(self, thinlens=None, doe=None, sensor=None, aper=None, device=DEVICE): - super().__init__() - - self.aper = aper - self.thinlens = thinlens - self.doe = doe - self.sensor = sensor - - if aper is None: - self.surfaces = [self.thinlens, self.doe, self.sensor] - else: - self.surfaces = [self.aper, self.thinlens, self.doe, self.sensor] - - assert doe.l == sensor.l, "DOE and sensor should have the same physical size." - self.to(device) - - def load_example(self): - """Generate an example lens group. The lens is focused to infinity.""" - self.thinlens = ThinLens(foclen=50, d=0, r=12.7) - self.doe = DOE(d=40, l=4, res=1024) - self.sensor = Sensor(d=50, l=4) - - self.surfaces = [self.thinlens, self.doe, self.sensor] + def __init__(self, filename, sensor_res=[1024, 1024]): + super().__init__(filename, sensor_res) + + def load_file(self, filename): + """Load lens from a file.""" + self.surfaces = [] + with open(filename, "r") as f: + data = json.load(f) + d = 0.0 + for surf_dict in data["surfaces"]: + if surf_dict["type"] == "Aperture": + s = Aperture(r=surf_dict["r"], d=d) + self.aperture = s + + elif surf_dict["type"] == "DOE": + s = DOE( + l=surf_dict["l"], + d=d, + res=surf_dict["res"], + fab_ps=surf_dict["fab_ps"], + param_model=surf_dict["param_model"], + ) + if surf_dict["doe_path"] is not None: + s.load_doe(surf_dict["doe_path"]) + self.doe = s + + elif surf_dict["type"] == "ThinLens": + s = ThinLens(foclen=surf_dict["foclen"], r=surf_dict["r"], d=d) + self.thinlens = s + + elif surf_dict["type"] == "Sensor": + s = Sensor(l=surf_dict["l"], d=d, res=surf_dict["res"]) + self.sensor = s + + else: + raise Exception("Surface type not implemented.") + + self.surfaces.append(s) + + if not surf_dict["type"] == "Sensor": + d += surf_dict["d_next"] + + self.lens_info = data["info"] + + def write_file(self, filename): + """Write the lens into a file.""" + # Save DOE to a file + doe_filename = filename.replace(".json", "_doe.pth") + self.doe.save_doe(doe_filename) + + # Save lens to a file + data = {} + data["info"] = self.lens_info if hasattr(self, "lens_info") else "None" + data["surfaces"] = [] + for i, s in enumerate(self.surfaces): + surf_dict = {"idx": i + 1} + + surf_dict = s.surf_dict() + if isinstance(s, DOE): + surf_dict["doe_path"] = doe_filename + surf_dict.update(surf_dict) + + if i < len(self.surfaces) - 1: + surf_dict["d_next"] = ( + self.surfaces[i + 1].d.item() - self.surfaces[i].d.item() + ) - def load_example2(self): - """Generate an example lens group. The lens is focused to -100mm.""" - self.thinlens = ThinLens(foclen=50, d=0, r=12.7) - self.doe = DOE(d=40, l=4, res=1024) - self.sensor = Sensor(d=100, l=4) + data["surfaces"].append(surf_dict) - self.surfaces = [self.thinlens, self.doe, self.sensor] + with open(filename, "w") as f: + json.dump(data, f, indent=4) - def forward(self, field): + def prop_wave(self, field): """Propagate a wavefront through the lens group. Args: @@ -77,7 +105,7 @@ def forward(self, field): return field # ============================================= - # PSF related functions + # PSF-related functions # ============================================= def psf(self, point=[0, 0, -10000.0], ks=101, wvln=0.589): """Calculate monochromatic point PSF by wave propagation approach. @@ -108,12 +136,11 @@ def psf(self, point=[0, 0, -10000.0], ks=101, wvln=0.589): res=field_res, wvln=wvln, fieldz=self.surfaces[0].d.item(), + device=self.device, ) - # Calculate PSF on the sensor. - psf_full_res = self.forward(inp_field)[ - 0, 0, :, : - ] # shape of [H_sensor, W_sensor] + # Calculate PSF on the sensor. Shape [H_sensor, W_sensor] + psf_full_res = self.prop_wave(inp_field)[0, 0, :, :] # Crop the valid patch of the full-resolution PSF coord_c_i = int((1 + y) * sensor.res[0] / 2) @@ -129,201 +156,279 @@ def psf(self, point=[0, 0, -10000.0], ks=101, wvln=0.589): return psf_out - def psf_rgb(self, point=[0, 0, -DEPTH], ks=101): - """Calculate RGB point PSF of DOEThinLens. + def draw_psf(self, depth=DEPTH, ks=101, save_name="./psf_doethinlens.png"): + """Draw on-axis RGB PSF.""" + psf_rgb = self.psf_rgb(point=[0, 0, depth], ks=ks) + save_image(psf_rgb.unsqueeze(0), save_name, normalize=True) - Args: - point (list, optional): Point source position. Defaults to [0, 0, -DEPTH]. - ks (int, optional): PSF kernel size. Defaults to 101. + # ============================================= + # Utils + # ============================================= + def get_optimizer(self, lr): + return self.doe.get_optimizer(lr=lr) - Returns: - psf (tensor): RGB PSF. Shape of [3, ks, ks] - """ - psf = [] - for wvln in WAVE_RGB: - psf_mono = self.psf(point=point, ks=ks, wvln=wvln) - psf.append(psf_mono) + def draw_layout(self, save_name="./doethinlens.png"): + """Draw lens setup.""" + fig, ax = plt.subplots() - psf = torch.stack(psf, dim=0) # shape [3, ks, ks] - return psf + # Draw aperture + d = self.aperture.d.item() + r = self.aperture.r + ax.plot([d, d], [r, r + 0.5], "gray") + ax.plot([d - 0.5, d + 0.5], [r, r], "gray") # top wedge + ax.plot([d, d], [-r, -r - 0.5], "gray") + ax.plot([d - 0.5, d + 0.5], [-r, -r], "gray") # bottom wedge - def psf_board_band(self, point=[0, 0, -DEPTH], ks=101): - """Calculate boardband RGB PSF""" - psf_r = [] - for i, wvln in enumerate(WAVE_BOARD_BAND): - psf = self.psf(point=point, ks=ks, wvln=wvln) - psf_r.append(psf * RED_RESPONSE[i]) - psf_r = torch.stack(psf_r, dim=0).sum(dim=0) / sum(RED_RESPONSE) + # Draw thinlens + d = self.thinlens.d.item() + r = self.thinlens.r + arrow_length = r + ax.arrow( + d, + -arrow_length, + 0, + 2 * arrow_length, + head_width=0.5, + head_length=0.5, + fc="black", + ec="black", + length_includes_head=True, + ) + ax.arrow( + d, + arrow_length, + 0, + -2 * arrow_length, + head_width=0.5, + head_length=0.5, + fc="black", + ec="black", + length_includes_head=True, + ) + + # Draw DOE + d = self.doe.d.item() + doe_l = self.doe.l + ax.plot( + [d, d], [-doe_l / 2, doe_l / 2], "orange", linestyle="--", dashes=[1, 1] + ) - psf_g = [] - for i, wvln in enumerate(WAVE_BOARD_BAND): - psf = self.psf(point=point, ks=ks, wvln=wvln) - psf_g.append(psf * GREEN_RESPONSE[i]) - psf_g = torch.stack(psf_g, dim=0).sum(dim=0) / sum(GREEN_RESPONSE) + # Draw sensor + d = self.sensor.d.item() + sensor_l = self.sensor.l + width = 0.2 # Width of the rectangle + rect = plt.Rectangle( + (d - width / 2, -sensor_l / 2), + width, + sensor_l, + facecolor="none", + edgecolor="black", + linewidth=1, + ) + ax.add_patch(rect) - psf_b = [] - for i, wvln in enumerate(WAVE_BOARD_BAND): - psf = self.psf(point=point, ks=ks, wvln=wvln) - psf_b.append(psf * BLUE_RESPONSE[i]) - psf_b = torch.stack(psf_b, dim=0).sum(dim=0) / sum(BLUE_RESPONSE) + ax.set_aspect("equal") + ax.axis("off") + fig.savefig(save_name, dpi=600, bbox_inches="tight") + plt.close(fig) - psfs = torch.stack([psf_r, psf_g, psf_b], dim=0) # shape [3, ks, ks] - return psfs +class DoeLens(Lens): + """Paraxial diffractive lens model. The lens consists of a diffractive optical element (DOE) and a sensor. DOE is modeled using paraxial wave optics.""" - def psf_map(self, grid=9, ks=101, wvln=0.589, depth=DEPTH): - """Generate PSF map for DoeThinlens. + def __init__(self, filename, sensor_res=[1024, 1024]): + super().__init__(filename, sensor_res) - Args: - grid (int, optional): Grid size. Defaults to 9. - ks (int, optional): PSF kernel size. Defaults to 101. - wvln (float, optional): wvln. Defaults to 0.589. - depth (float, optional): Depth of the point source. Defaults to DEPTH. + def load_example(self): + self.doe = DOE(d=0, l=4, res=4000) + self.doe.init_param_model(param_model="fresnel", f0=50, fresnel_wvln=0.589) + self.sensor = Sensor(d=50, l=4) + self.surfaces = [self.doe, self.sensor] + + def load_file(self, filename): + """Load lens from a file.""" + self.surfaces = [] + with open(filename, "r") as f: + data = json.load(f) + d = 0.0 + for surf_dict in data["surfaces"]: + if surf_dict["type"] == "DOE": + s = DOE( + l=surf_dict["l"], + d=d, + res=surf_dict["res"], + fab_ps=surf_dict["fab_ps"], + param_model=surf_dict["param_model"], + ) + if surf_dict["doe_path"] is not None: + s.load_doe(surf_dict["doe_path"]) + self.doe = s + + elif surf_dict["type"] == "Sensor": + s = Sensor(l=surf_dict["l"], d=d, res=surf_dict["res"]) + self.sensor = s + + else: + raise Exception("Surface type not implemented.") + + self.surfaces.append(s) + + if not surf_dict["type"] == "Sensor": + d += surf_dict["d_next"] + + self.lens_info = data["info"] + + def write_file(self, filename): + """Write the lens into a file.""" + # Save DOE to a file + doe_filename = filename.replace(".json", "_doe.pth") + self.doe.save_doe(doe_filename) + + # Save lens to a file + data = {} + data["info"] = self.lens_info if hasattr(self, "lens_info") else "None" + data["surfaces"] = [] + for i, s in enumerate(self.surfaces): + surf_dict = {"idx": i + 1} + + surf_dict = s.surf_dict() + if isinstance(s, DOE): + surf_dict["doe_path"] = doe_filename + surf_dict.update(surf_dict) + + if i < len(self.surfaces) - 1: + surf_dict["d_next"] = ( + self.surfaces[i + 1].d.item() - self.surfaces[i].d.item() + ) - Returns: - psf_map (tensor): PSF map. Shape of [grid*ks, grid*ks] - """ - points = self.point_source_grid(depth=depth, grid=grid, center=True) - points = points.reshape(-1, 3) + data["surfaces"].append(surf_dict) - psfs = [] - for i in range(points.shape[0]): - point = points[i, ...] - psf = self.psf(point=point, ks=ks, wvln=wvln) - psfs.append(psf) + with open(filename, "w") as f: + json.dump(data, f, indent=4) - psf_map = torch.stack(psfs).unsqueeze(1) - psf_map = make_grid(psf_map, nrow=grid, padding=0)[0, :, :] + def prop_wave(self, field): + """Propagate a wavefront through the lens group. - return psf_map + Args: + field (Field): Input wavefront. - def point_source_grid( - self, depth, grid=9, normalized=True, quater=False, center=False - ): - """ - Generate point source grid for PSF calculation. + Returns: + field (torch.tensor): Output energy distribution. Shape of [H_sensor, W_sensor] """ - # ==> Use center of each patch - if grid == 1: - x, y = torch.tensor([[0.0]]), torch.tensor([[0.0]]) - assert not quater, "Quater should be False when grid is 1." - else: - if center: - half_bin_size = 1 / 2 / (grid - 1) - x, y = torch.meshgrid( - torch.linspace(-1 + half_bin_size, 1 - half_bin_size, grid), - torch.linspace(1 - half_bin_size, -1 + half_bin_size, grid), - indexing="xy", - ) - # ==> Use corner - else: - x, y = torch.meshgrid( - torch.linspace(-0.98, 0.98, grid), - torch.linspace(0.98, -0.98, grid), - indexing="xy", - ) + for surf in self.surfaces: + field = surf(field) - z = torch.full((grid, grid), depth) - point_source = torch.stack([x, y, z], dim=-1) + return field - # ==> Use quater of the sensor plane to save memory - if quater: - z = torch.full((grid, grid), depth) - point_source = torch.stack([x, y, z], dim=-1) - bound_i = grid // 2 if grid % 2 == 0 else grid // 2 + 1 - bound_j = grid // 2 - point_source = point_source[0:bound_i, bound_j:, :] + # ============================================= + # PSF-related functions + # ============================================= + def psf(self, point=[0, 0, -10000.0], ks=101, wvln=0.589): + """Calculate monochromatic point PSF by wave propagation approach. - if not normalized: - scale = self.calc_scale_pinhole(depth) - point_source[..., 0] *= scale * self.sensor_size[0] / 2 - point_source[..., 1] *= scale * self.sensor_size[1] / 2 + For the shifted phase issue, refer to "Modeling off-axis diffraction with the least-sampling angular spectrum method". - return point_source + Args: + point (list, optional): Point source position. Defaults to [0, 0, -10000]. + ks (int, optional): PSF kernel size. Defaults to 256. + wvln (float, optional): wvln. Defaults to 0.55. + padding (bool, optional): Whether to pad the PSF. Defaults to True. - def point_source_radial(self, depth, grid=9, normalized=True, center=False): - """ - Generate radial point source grid for PSF calculation. + Returns: + psf_out (tensor): PSF. shape [ks, ks] """ - if grid == 1: - x = torch.tensor([0.0]) - else: - # Select center of bin to calculate PSF - if center: - half_bin_size = 1 / 2 / (grid - 1) - x = torch.linspace(0, 1 - half_bin_size, grid) - else: - x = torch.linspace(0, 0.98, grid) - - z = torch.full_like(x, depth) - point_source = torch.stack([x, x, z], dim=-1) - return point_source + # Get input field + x, y, z = point + sensor = self.sensor + sensor_l = sensor.l + field_res = self.doe.res + scale = -z / sensor.d.item() + x_obj, y_obj = x * scale * sensor_l / 2, y * scale * sensor_l / 2 - def draw_psf(self, depth=DEPTH, ks=101, save_name="./psf_doethinlens.png"): - """Draw RGB on-axis PSF.""" - psfs = [] - for wvln in WAVE_RGB: - psf = self.psf(point=[0, 0, depth], ks=ks, wvln=wvln) - psfs.append(psf) + # We have to sample high resolution to meet Nyquist sampling constraint. + inp_field = point_source_field( + point=[x_obj, y_obj, z], + phy_size=[sensor_l, sensor_l], + res=field_res, + wvln=wvln, + fieldz=self.surfaces[0].d.item(), + device=self.device, + ) + + # Calculate PSF on the sensor. Shape [H_sensor, W_sensor] + psf_full_res = self.prop_wave(inp_field)[0, 0, :, :] + + # Crop the valid patch of the full-resolution PSF + coord_c_i = int((1 + y) * sensor.res[0] / 2) + coord_c_j = int((1 - x) * sensor.res[1] / 2) + psf_full_res = F.pad( + psf_full_res, [ks // 2, ks // 2, ks // 2, ks // 2], mode="constant", value=0 + ) + psf_out = psf_full_res[coord_c_i : coord_c_i + ks, coord_c_j : coord_c_j + ks] - psfs = torch.stack(psfs, dim=0) # shape [3, ks, ks] - save_image(psfs.unsqueeze(0), save_name, normalize=True) + # Normalize PSF + psf_out /= psf_out.sum() + psf_out = torch.flip(psf_out, [0, 1]) - def draw_psf_map( + return psf_out + + def draw_psf( self, - grid=8, depth=DEPTH, ks=101, - log_scale=False, - save_name="./psf_map_doethinlens.png", + save_name="./psf_doelens.png", + log_scale=True, + eps=1e-4, ): - """Draw RGB PSF map of the DOE thin lens.""" - # Calculate PSF map - psf_maps = [] - for wvln in WAVE_RGB: - psf_map = self.psf_map(grid=grid, depth=depth, ks=ks, wvln=wvln) - psf_maps.append(psf_map) - psf_map = torch.stack(psf_maps, dim=0) # shape [3, grid*ks, grid*ks] - - # Data processing for visualization - if log_scale: - psf_map = torch.log(psf_map + 1e-4) - psf_map = (psf_map - psf_map.min()) / (psf_map.max() - psf_map.min()) + """ + Draw on-axis RGB PSF. - save_image(psf_map.unsqueeze(0), save_name, normalize=True) + Args: + depth (float): Depth of the point source + ks (int): Size of the PSF kernel + save_name (str): Path to save the PSF image + log_scale (bool): If True, display PSF in log scale + """ + psf_rgb = self.psf_rgb(point=[0, 0, depth], ks=ks) - def get_optimizer(self, lr): - return self.doe.get_optimizer(lr=lr) + if log_scale: + psf_rgb = torch.log10(psf_rgb + eps) + psf_rgb = (psf_rgb - psf_rgb.min()) / (psf_rgb.max() - psf_rgb.min()) + save_name = save_name.replace(".png", "_log.png") + + save_image(psf_rgb.unsqueeze(0), save_name, normalize=True) # ============================================= # Utils # ============================================= - def draw_layout(self, save_name="./doethinlens.png"): + def get_optimizer(self, lr): + return self.doe.get_optimizer(lr=lr) + + def draw_layout(self, save_name="./doelens.png"): """Draw lens setup.""" fig, ax = plt.subplots() - # Plot thin lens - d = self.thinlens.d.item() - r = self.thinlens.r - roc = r * 2 # A manually set roc for plotting - r_ls = np.arange(-r, r, 0.01) - d1_ls = d - (np.sqrt(roc**2 - r_ls**2) - np.sqrt(roc**2 - r**2)) - d2_ls = d + (np.sqrt(roc**2 - r_ls**2) - np.sqrt(roc**2 - r**2)) - ax.plot(d1_ls, r_ls, "black") - ax.plot(d2_ls, r_ls, "black") - - # Plot DOE + # Draw DOE d = self.doe.d.item() - l = self.doe.l - ax.plot([d, d], [-l / 2, l / 2], "black") + doe_l = self.doe.l + ax.plot( + [d, d], [-doe_l / 2, doe_l / 2], "orange", linestyle="--", dashes=[1, 1] + ) - # Plot sensor + # Draw sensor d = self.sensor.d.item() - l = self.sensor.l - ax.plot([d, d], [-l / 2, l / 2], "black") + sensor_l = self.sensor.l + width = 0.2 # Width of the rectangle + rect = plt.Rectangle( + (d - width / 2, -sensor_l / 2), + width, + sensor_l, + facecolor="none", + edgecolor="black", + linewidth=1, + ) + ax.add_patch(rect) - # ax.set_xlim(-1, 100) - # ax.set_ylim(-100, 100) ax.set_aspect("equal") ax.axis("off") fig.savefig(save_name, dpi=600, bbox_inches="tight") diff --git a/deeplens/lens.py b/deeplens/lens.py index 30e5a03..59a60c2 100644 --- a/deeplens/lens.py +++ b/deeplens/lens.py @@ -4,6 +4,7 @@ """ import torch +from torchvision.utils import make_grid, save_image from .optics import ( DeepObj, init_device, @@ -16,26 +17,25 @@ GREEN_RESPONSE, BLUE_RESPONSE, DEPTH, - render_psf_map, ) class Lens(DeepObj): """Geolens class. A geometric lens consisting of refractive surfaces, simulate with ray tracing. May contain diffractive surfaces, but still use ray tracing to simulate.""" - def __init__(self, filename=None, sensor_res=[1024, 1024]): + def __init__(self, filename, sensor_res=[1024, 1024]): """A lens class.""" super(Lens, self).__init__() self.device = init_device() # Load lens file self.lens_name = filename - self.load_file(filename, sensor_res) + self.load_file(filename) self.to(self.device) - # Lens calculation - self.prepare_sensor(sensor_res) - self.post_computation() + # # Lens calculation + # self.prepare_sensor(sensor_res) + # self.post_computation() def load_file(self, filename): """Load lens from a file.""" @@ -58,62 +58,173 @@ def post_computation(self): # =========================================== def psf(self, points, ks=51, wvln=0.589, **kwargs): """Compute monochrome point PSF. This function is differentiable.""" - pass + raise NotImplementedError - def psf_rgb(self, points, ks=51, **kwargs): + def psf_rgb(self, point, ks=51, **kwargs): """Compute RGB point PSF. This function is differentiable.""" psfs = [] for wvln in WAVE_RGB: - psfs.append(self.psf(points=points, ks=ks, wvln=wvln, **kwargs)) + psfs.append(self.psf(point=point, ks=ks, wvln=wvln, **kwargs)) psf_rgb = torch.stack(psfs, dim=-3) # shape [3, ks, ks] or [N, 3, ks, ks] return psf_rgb def psf_narrow_band(self, points, ks=51, **kwargs): - """Should be migrated to psf_rgb.""" + """Should be migrated to psf_rgb. + + In this function we use an average for different wavelengths. Actually we should use the sensor response function. + """ + # Red psf_r = [] - for i, wvln in enumerate(WAVE_RED): + for _, wvln in enumerate(WAVE_RED): psf_r.append(self.psf(points=points, wvln=wvln, ks=ks, **kwargs)) psf_r = torch.stack(psf_r, dim=-3).mean(dim=-3) + # Green psf_g = [] - for i, wvln in enumerate(WAVE_GREEN): + for _, wvln in enumerate(WAVE_GREEN): psf_g.append(self.psf(points=points, wvln=wvln, ks=ks, **kwargs)) psf_g = torch.stack(psf_g, dim=-3).mean(dim=-3) + # Blue psf_b = [] - for i, wvln in enumerate(WAVE_BLUE): + for _, wvln in enumerate(WAVE_BLUE): psf_b.append(self.psf(points=points, wvln=wvln, ks=ks, **kwargs)) psf_b = torch.stack(psf_b, dim=-3).mean(dim=-3) + # RGB psf = torch.stack([psf_r, psf_g, psf_b], dim=-3) return psf def psf_spectrum(self, points, ks=51, **kwargs): """Should be migrated to psf_rgb.""" + # Red psf_r = [] for i, wvln in enumerate(WAVE_BOARD_BAND): psf = self.psf(points=points, ks=ks, wvln=wvln, **kwargs) psf_r.append(psf * RED_RESPONSE[i]) psf_r = torch.stack(psf_r, dim=0).sum(dim=0) / sum(RED_RESPONSE) + # Green psf_g = [] for i, wvln in enumerate(WAVE_BOARD_BAND): psf = self.psf(points=points, ks=ks, wvln=wvln, **kwargs) psf_g.append(psf * GREEN_RESPONSE[i]) psf_g = torch.stack(psf_g, dim=0).sum(dim=0) / sum(GREEN_RESPONSE) + # Blue psf_b = [] for i, wvln in enumerate(WAVE_BOARD_BAND): psf = self.psf(points=points, ks=ks, wvln=wvln, **kwargs) psf_b.append(psf * BLUE_RESPONSE[i]) psf_b = torch.stack(psf_b, dim=0).sum(dim=0) / sum(BLUE_RESPONSE) + # RGB psf = torch.stack([psf_r, psf_g, psf_b], dim=0) # shape [3, ks, ks] return psf - def psf_map(self, grid=21, ks=51, depth=-20000.0, **kwargs): + def draw_psf(self, depth=DEPTH, ks=101, save_name="./psf.png"): + """Draw RGB on-axis PSF.""" + psfs = [] + for wvln in WAVE_RGB: + psf = self.psf(point=[0, 0, depth], ks=ks, wvln=wvln) + psfs.append(psf) + + psfs = torch.stack(psfs, dim=0) # shape [3, ks, ks] + save_image(psfs.unsqueeze(0), save_name, normalize=True) + + def point_source_grid( + self, depth, grid=9, normalized=True, quater=False, center=False + ): + """ + Generate point source grid for PSF calculation. + """ + # ==> Use center of each patch + if grid == 1: + x, y = torch.tensor([[0.0]]), torch.tensor([[0.0]]) + assert not quater, "Quater should be False when grid is 1." + else: + if center: + half_bin_size = 1 / 2 / (grid - 1) + x, y = torch.meshgrid( + torch.linspace(-1 + half_bin_size, 1 - half_bin_size, grid), + torch.linspace(1 - half_bin_size, -1 + half_bin_size, grid), + indexing="xy", + ) + # ==> Use corner + else: + x, y = torch.meshgrid( + torch.linspace(-0.98, 0.98, grid), + torch.linspace(0.98, -0.98, grid), + indexing="xy", + ) + + z = torch.full((grid, grid), depth) + point_source = torch.stack([x, y, z], dim=-1) + + # ==> Use quater of the sensor plane to save memory + if quater: + z = torch.full((grid, grid), depth) + point_source = torch.stack([x, y, z], dim=-1) + bound_i = grid // 2 if grid % 2 == 0 else grid // 2 + 1 + bound_j = grid // 2 + point_source = point_source[0:bound_i, bound_j:, :] + + if not normalized: + raise Exception("Need to specify the scale.") + scale = self.calc_scale_pinhole(depth) + point_source[..., 0] *= scale * self.sensor_size[0] / 2 + point_source[..., 1] *= scale * self.sensor_size[1] / 2 + + return point_source + + def psf_map(self, grid=21, ks=51, depth=-20000.0, wvln=0.589, **kwargs): """Compute PSF map.""" - pass + # raise NotImplementedError + points = self.point_source_grid(depth=depth, grid=grid, center=True) + points = points.reshape(-1, 3) + + psfs = [] + for i in range(points.shape[0]): + point = points[i, ...] + psf = self.psf(point=point, ks=ks, wvln=wvln) + psfs.append(psf) + + psf_map = torch.stack(psfs).unsqueeze(1) + psf_map = make_grid(psf_map, nrow=grid, padding=0)[0, :, :] + + return psf_map + + def psf_map_rgb(self, grid=21, ks=51, depth=-20000.0, **kwargs): + """Compute RGB PSF map.""" + psfs = [] + for wvln in WAVE_RGB: + psf_map = self.psf_map(grid=grid, ks=ks, depth=depth, wvln=wvln, **kwargs) + psfs.append(psf_map) + psf_map = torch.stack(psfs, dim=0) # shape [3, grid*ks, grid*ks] + return psf_map + + def draw_psf_map( + self, + grid=8, + depth=DEPTH, + ks=101, + log_scale=False, + save_name="./psf_map.png", + ): + """Draw RGB PSF map of the DOE thin lens.""" + # Calculate PSF map + psf_maps = [] + for wvln in WAVE_RGB: + psf_map = self.psf_map(grid=grid, depth=depth, ks=ks, wvln=wvln) + psf_maps.append(psf_map) + psf_map = torch.stack(psf_maps, dim=0) # shape [3, grid*ks, grid*ks] + + # Data processing for visualization + if log_scale: + psf_map = torch.log(psf_map + 1e-4) + psf_map = (psf_map - psf_map.min()) / (psf_map.max() - psf_map.min()) + + save_image(psf_map.unsqueeze(0), save_name, normalize=True) # =========================================== # Rendering-ralated functions diff --git a/deeplens/optics/materials.py b/deeplens/optics/materials.py index 3520f62..ddab5cc 100644 --- a/deeplens/optics/materials.py +++ b/deeplens/optics/materials.py @@ -9,6 +9,12 @@ def __init__(self, name=None): self.name = "vacuum" if name is None else name.lower() self.load_dispersion() + def get_name(self): + if self.dispersion == "optimizable": + return f"{self.n.item():.4f}/{self.V.item():.2f}" + else: + return self.name + def load_dispersion(self): """Load material dispersion equation.""" if self.name in SELLMEIER_TABLE: @@ -115,7 +121,7 @@ def match_material(self): self.load_dispersion() - def get_optimizer_params(self, lr=1e-3): + def get_optimizer_params(self, lr=[1e-4, 1e-2]): """Optimize the material parameters (n, V).""" self.n = torch.tensor(self.n).to(self.device) self.V = torch.tensor(self.V).to(self.device) @@ -124,7 +130,7 @@ def get_optimizer_params(self, lr=1e-3): self.V.requires_grad = True self.dispersion = "optimizable" - params = {"params": [self.A, self.B], "lr": lr} + params = [{"params": [self.n], "lr": lr[0]}, {"params": [self.V], "lr": lr[1]}] return params diff --git a/deeplens/optics/surfaces.py b/deeplens/optics/surfaces.py index 5843508..618d8d7 100644 --- a/deeplens/optics/surfaces.py +++ b/deeplens/optics/surfaces.py @@ -354,7 +354,7 @@ def surf_dict(self): "r": float(f"{self.r:.6f}"), "(d)": float(f"{self.d.item():.3f}"), "is_square": self.is_square, - "mat2": self.mat2.name, + "mat2": self.mat2.get_name(), } return surf_dict @@ -800,7 +800,7 @@ def get_optimizer_params( f"params.append({{'params': [self.ai{2*i}], 'lr': lr[3] * decay**{i-1}}})" ) - if optim_mat and self.mat2.name != "air": + if optim_mat and self.mat2.get_name() != "air": params += self.mat2.get_optimizer_params() return params @@ -836,7 +836,7 @@ def surf_dict(self): "(d)": float(f"{self.d.item():.3f}"), "k": float(f"{self.k.item():.6f}"), "ai": [], - "mat2": self.mat2.name, + "mat2": self.mat2.get_name(), } for i in range(1, self.ai_degree + 1): exec(f"surf_dict['(ai{2*i})'] = self.ai{2*i}.item()") @@ -852,7 +852,7 @@ def zmx_str(self, surf_idx, d_next): assert ( self.ai is not None or self.k != 0 ), "Spheric surface is re-implemented in Spheric class." - if self.mat2.name == "air": + if self.mat2.get_name() == "air": zmx_str = f"""SURF {surf_idx} TYPE EVENASPH CURV {self.c.item()} @@ -870,7 +870,7 @@ def zmx_str(self, surf_idx, d_next): TYPE EVENASPH CURV {self.c.item()} DISZ {d_next.item()} - GLAS {self.mat2.name.upper()} 0 0 {self.mat2.n} {self.mat2.V} + GLAS {self.mat2.get_name().upper()} 0 0 {self.mat2.n} {self.mat2.V} DIAM {self.r * 2} PARM 1 {self.ai2.item()} PARM 2 {self.ai4.item()} @@ -989,7 +989,7 @@ def get_optimizer_params(self, lr, optim_mat=False): else: raise Exception("Unsupported cubic degree!") - if optim_mat and self.mat2.name != "air": + if optim_mat and self.mat2.get_name() != "air": params += self.mat2.get_optimizer_params() return params @@ -1428,7 +1428,7 @@ def surf_dict(self): "param_model": self.param_model, "f0": self.f0.item(), "(d)": float(f"{self.d.item():.3f}"), - "mat2": self.mat2.name, + "mat2": self.mat2.get_name(), } elif self.param_model == "binary2": @@ -1442,7 +1442,7 @@ def surf_dict(self): "order6": self.order6.item(), "order8": self.order8.item(), "(d)": f"{float(self.d.item()):.3f}", - "mat2": self.mat2.name, + "mat2": self.mat2.get_name(), } elif self.param_model == "poly1d": @@ -1458,7 +1458,7 @@ def surf_dict(self): "order6": self.order6.item(), "order7": self.order7.item(), "(d)": float(f"{self.d.item():.3f}"), - "mat2": self.mat2.name, + "mat2": self.mat2.get_name(), } elif self.param_model == "grating": @@ -1470,7 +1470,7 @@ def surf_dict(self): "theta": self.theta.item(), "alpha": self.alpha.item(), "(d)": float(f"{self.d.item():.3f}"), - "mat2": self.mat2.name, + "mat2": self.mat2.get_name(), } return surf_dict @@ -1575,7 +1575,7 @@ def surf_dict(self): "l": self.l, "(d)": float(f"{self.d.item():.3f}"), "is_square": True, - "mat2": self.mat2.name, + "mat2": self.mat2.get_name(), } return surf_dict @@ -1642,7 +1642,7 @@ def get_optimizer_params(self, lr=[0.001, 0.001], optim_mat=False): params.append({"params": [self.c], "lr": lr[0]}) params.append({"params": [self.d], "lr": lr[1]}) - if optim_mat and self.mat2.name != "air": + if optim_mat and self.mat2.get_name() != "air": params += self.mat2.get_optimizer_params() return params @@ -1656,14 +1656,14 @@ def surf_dict(self): "(c)": float(f"{self.c.item():.3f}"), "roc": float(f"{roc:.3f}"), "(d)": float(f"{self.d.item():.3f}"), - "mat2": self.mat2.name, + "mat2": self.mat2.get_name(), } return surf_dict def zmx_str(self, surf_idx, d_next): """Return Zemax surface string.""" - if self.mat2.name == "air": + if self.mat2.get_name() == "air": zmx_str = f"""SURF {surf_idx} TYPE STANDARD CURV {self.c.item()} @@ -1675,7 +1675,7 @@ def zmx_str(self, surf_idx, d_next): TYPE STANDARD CURV {self.c.item()} DISZ {d_next.item()} - GLAS {self.mat2.name.upper()} 0 0 {self.mat2.n} {self.mat2.V} + GLAS {self.mat2.get_name().upper()} 0 0 {self.mat2.n} {self.mat2.V} DIAM {self.r*2} """ return zmx_str diff --git a/deeplens/optics/surfaces_diffractive.py b/deeplens/optics/surfaces_diffractive.py index 6d0d857..fa28e28 100644 --- a/deeplens/optics/surfaces_diffractive.py +++ b/deeplens/optics/surfaces_diffractive.py @@ -15,9 +15,7 @@ class DOE(DeepObj): - def __init__( - self, l, d, res=None, fab_ps=0.001, param_model="pixel2d", device="cpu" - ): + def __init__(self, l, d, res, fab_ps=0.001, param_model="pixel2d", device="cpu"): """DOE class.""" super().__init__() @@ -77,7 +75,10 @@ def init_param_model(self, param_model="none", **kwargs): if self.param_model == "fresnel": # "Phase fresnel" or "Fresnel zone plate (FPZ)" - self.f0 = torch.tensor([100.0]) + f0 = kwargs.get("f0", 100.0) + self.f0 = torch.tensor([f0]) + fresnel_wvln = kwargs.get("fresnel_wvln", 0.55) + self.fresnel_wvln = fresnel_wvln elif self.param_model == "cubic": self.a3 = torch.tensor([0.001]) @@ -116,6 +117,10 @@ def init_param_model(self, param_model="none", **kwargs): self.to(self.device) + def save_doe(self, save_path="./doe.pth"): + """Save DOE phase map.""" + self.save_ckpt(save_path) + def save_ckpt(self, save_path="./doe.pth"): """Save DOE phase map.""" if self.param_model == "fresnel": @@ -177,6 +182,10 @@ def save_ckpt(self, save_path="./doe.pth"): else: raise Exception("Unknown parameterization.") + def load_doe(self, load_path="./doe_fab.pth"): + """Load DOE phase map.""" + self.load_ckpt(load_path) + def load_ckpt(self, load_path="./doe.pth"): """Load DOE phase map.""" ckpt = torch.load(load_path) @@ -248,7 +257,10 @@ def get_pmap(self): pmap = ( -2 * np.pi - * torch.fmod((self.x**2 + self.y**2) / (2 * 0.55e-3 * self.f0), 1) + * torch.fmod( + (self.x**2 + self.y**2) / (2 * self.fresnel_wvln * 1e-3 * self.f0), + 1, + ) ) # unit [mm] elif self.param_model == "cubic": @@ -606,9 +618,30 @@ def draw_cross_section(self, save_name="./DOE_corss_sec.png"): fig.savefig(save_name, dpi=600, bbox_inches="tight") plt.close(fig) + # ======================================= + # Utils + # ======================================= + def surf_dict(self): + """Return a dict of surface.""" + surf_dict = { + "type": "DOE", + "l": float(f"{self.l:.6f}"), + "res": self.res, + "fab_ps": float(f"{self.fab_ps:.6f}"), + "is_square": True, + "param_model": self.param_model, + "doe_path": None, + } + return surf_dict + class ThinLens(DeepObj): - """Paraxial optics, consisting of both thin lens and thick lens.""" + """Paraxial optical model for refractive lens. + + Two types of thin lenses supported: + (1) Thin lens without chromatic aberration. + (2) Extended thin lens with chromatic aberration. + """ def __init__(self, foclen, d, r): super().__init__() @@ -619,7 +652,7 @@ def __init__(self, foclen, d, r): self.chromatic = False def load_foclens(self, wvlns, foclens): - """Load a list of wvlns and corresponding focus lenghs for interpolation.""" + """Load a list of wvlns and corresponding focus lenghs for interpolation. This function is used for chromatic aberration simulation.""" self.chromatic = True self.ref_wvlns = wvlns self.ref_foclens = foclens @@ -664,11 +697,20 @@ def forward(self, field): return field + def surf_dict(self): + """Return a dict of surface.""" + surf_dict = { + "type": "ThinLens", + "foclen": float(f"{self.foclen:.6f}"), + "r": float(f"{self.r:.6f}"), + } + return surf_dict + class Aperture(DeepObj): def __init__(self, d, r, device="cpu"): super().__init__() - self.d = d + self.d = torch.tensor([d]) self.r = r def forward(self, field): @@ -679,6 +721,14 @@ def forward(self, field): return field + def surf_dict(self): + """Return a dict of surface.""" + surf_dict = { + "type": "Aperture", + "r": float(f"{self.r:.6f}"), + } + return surf_dict + class Sensor(DeepObj): def __init__(self, d, r=None, l=None, res=[2048, 2048], device="cpu"): @@ -726,6 +776,15 @@ def forward(self, field): return response + def surf_dict(self): + """Return a dict of surface.""" + surf_dict = { + "type": "Sensor", + "res": self.res, + "l": float(f"{self.l:.6f}"), + } + return surf_dict + def Zernike(z_coeff, grid=256): """Calculate phase map produced by the first 37 Zernike polynomials. The output zernike phase map is in real value, to use it in the future we need to convert it to complex value.""" diff --git a/deeplens/optics/wave.py b/deeplens/optics/wave.py index 4c6bb0b..1b5a3b6 100644 --- a/deeplens/optics/wave.py +++ b/deeplens/optics/wave.py @@ -69,8 +69,6 @@ def __init__( self.x, self.y = self.gen_xy_grid() self.z = torch.full_like(self.x, z) - self.to(device) - def load_img(self, img): """Use the pixel value of an image/batch as the amplitute. @@ -87,7 +85,7 @@ def load_img(self, img): phi = torch.zeros_like(amp) u = amp + 1j * phi - self.u = u.to(self.device) + # self.u = u.to(self.device) self.res = self.u.shape def load_pkl(self, data_path): @@ -105,8 +103,6 @@ def load_pkl(self, data_path): self.valid_phy_size = wave_data["valid_phy_size"] self.res = self.x.shape - self.to(self.device) - def save_data(self, save_path="./test.pkl"): data = { "amp": self.u.cpu().abs(), @@ -239,7 +235,7 @@ def gen_freq_grid(self): fy = y / (self.ps * self.phy_size[1]) return fx, fy - def show(self, data="irr", save_name=None): + def show(self, save_name=None, data="irr"): """Show the field.""" if data == "irr": value = self.u.detach().abs() ** 2 @@ -324,7 +320,6 @@ def Nyquist_zmin(self): def pad(self, Hpad, Wpad): """Pad the input field by (Hpad, Hpad, Wpad, Wpad). This step will also expand physical size of the field.""" - device = self.device # Pad directly self.u = F.pad(self.u, (Hpad, Hpad, Wpad, Wpad), mode="constant", value=0) diff --git a/deeplens/optics/waveoptics_utils.py b/deeplens/optics/waveoptics_utils.py index 49e4b2d..d2a998f 100644 --- a/deeplens/optics/waveoptics_utils.py +++ b/deeplens/optics/waveoptics_utils.py @@ -72,7 +72,12 @@ def plane_wave_field(phy_size, res, wvln=0.589, z=0.0): def point_source_field( - point=[0, 0, -1000.0], phy_size=[2, 2], res=[1024, 1024], wvln=0.589, fieldz=0 + point=[0, 0, -1000.0], + phy_size=[2, 2], + res=[1024, 1024], + wvln=0.589, + fieldz=0, + device="cpu", ): """Field on x0y plane generated by a point source. @@ -112,6 +117,7 @@ def point_source_field( u = r.min() / r * torch.exp(1j * phi) field = ComplexWave(u=u, wvln=wvln, phy_size=phy_size, res=res, z=fieldz) + field = field.to(device) return field diff --git a/lenses/paraxiallens/doelens.json b/lenses/paraxiallens/doelens.json new file mode 100644 index 0000000..0bcf812 --- /dev/null +++ b/lenses/paraxiallens/doelens.json @@ -0,0 +1,23 @@ +{ + "info": "A paraxial DOE lens system.", + "surfaces": [ + { + "idx": 1, + "type": "DOE", + "l": 4.0, + "res": [4000, 4000], + "fab_ps": 0.001, + "(d)": 1.2, + "is_square": true, + "param_model": "pixel2d", + "doe_path": null, + "d_next": 50.0 + }, + { + "idx": 2, + "type": "Sensor", + "l": 4.0, + "res": [2000, 2000] + } + ] +} \ No newline at end of file diff --git a/lenses/paraxiallens/doethinlens.json b/lenses/paraxiallens/doethinlens.json new file mode 100644 index 0000000..c4163fc --- /dev/null +++ b/lenses/paraxiallens/doethinlens.json @@ -0,0 +1,40 @@ +{ + "info": "A paraxial hybrid lens system with a diffractive optical element (DOE) and a thin lens.", + "surfaces": [ + { + "idx": 1, + "type": "Aperture", + "r": 2.0, + "(d)": 0.0, + "is_square": false, + "diffraction": false, + "d_next": 2.0 + }, + { + "idx": 2, + "type": "ThinLens", + "foclen": 50.0, + "r": 5.0, + "(d)": 0.2, + "d_next": 2.0 + }, + { + "idx": 3, + "type": "DOE", + "l": 4.0, + "res": [4000, 4000], + "fab_ps": 0.001, + "(d)": 1.2, + "is_square": true, + "param_model": "pixel2d", + "doe_path": null, + "d_next": 48.0 + }, + { + "idx": 4, + "type": "Sensor", + "l": 4.0, + "res": [2000, 2000] + } + ] +} \ No newline at end of file