Skip to content

Commit

Permalink
[new] paraxial optical model
Browse files Browse the repository at this point in the history
  • Loading branch information
singer-yang committed Dec 15, 2024
1 parent 9216077 commit 6762170
Show file tree
Hide file tree
Showing 9 changed files with 603 additions and 258 deletions.
525 changes: 315 additions & 210 deletions deeplens/diffraclens.py

Large diffs are not rendered by default.

141 changes: 126 additions & 15 deletions deeplens/lens.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import torch
from torchvision.utils import make_grid, save_image
from .optics import (
DeepObj,
init_device,
Expand All @@ -16,26 +17,25 @@
GREEN_RESPONSE,
BLUE_RESPONSE,
DEPTH,
render_psf_map,
)


class Lens(DeepObj):
"""Geolens class. A geometric lens consisting of refractive surfaces, simulate with ray tracing. May contain diffractive surfaces, but still use ray tracing to simulate."""

def __init__(self, filename=None, sensor_res=[1024, 1024]):
def __init__(self, filename, sensor_res=[1024, 1024]):
"""A lens class."""
super(Lens, self).__init__()
self.device = init_device()

# Load lens file
self.lens_name = filename
self.load_file(filename, sensor_res)
self.load_file(filename)
self.to(self.device)

# Lens calculation
self.prepare_sensor(sensor_res)
self.post_computation()
# # Lens calculation
# self.prepare_sensor(sensor_res)
# self.post_computation()

def load_file(self, filename):
"""Load lens from a file."""
Expand All @@ -58,62 +58,173 @@ def post_computation(self):
# ===========================================
def psf(self, points, ks=51, wvln=0.589, **kwargs):
"""Compute monochrome point PSF. This function is differentiable."""
pass
raise NotImplementedError

def psf_rgb(self, points, ks=51, **kwargs):
def psf_rgb(self, point, ks=51, **kwargs):
"""Compute RGB point PSF. This function is differentiable."""
psfs = []
for wvln in WAVE_RGB:
psfs.append(self.psf(points=points, ks=ks, wvln=wvln, **kwargs))
psfs.append(self.psf(point=point, ks=ks, wvln=wvln, **kwargs))
psf_rgb = torch.stack(psfs, dim=-3) # shape [3, ks, ks] or [N, 3, ks, ks]
return psf_rgb

def psf_narrow_band(self, points, ks=51, **kwargs):
"""Should be migrated to psf_rgb."""
"""Should be migrated to psf_rgb.
In this function we use an average for different wavelengths. Actually we should use the sensor response function.
"""
# Red
psf_r = []
for i, wvln in enumerate(WAVE_RED):
for _, wvln in enumerate(WAVE_RED):
psf_r.append(self.psf(points=points, wvln=wvln, ks=ks, **kwargs))
psf_r = torch.stack(psf_r, dim=-3).mean(dim=-3)

# Green
psf_g = []
for i, wvln in enumerate(WAVE_GREEN):
for _, wvln in enumerate(WAVE_GREEN):
psf_g.append(self.psf(points=points, wvln=wvln, ks=ks, **kwargs))
psf_g = torch.stack(psf_g, dim=-3).mean(dim=-3)

# Blue
psf_b = []
for i, wvln in enumerate(WAVE_BLUE):
for _, wvln in enumerate(WAVE_BLUE):
psf_b.append(self.psf(points=points, wvln=wvln, ks=ks, **kwargs))
psf_b = torch.stack(psf_b, dim=-3).mean(dim=-3)

# RGB
psf = torch.stack([psf_r, psf_g, psf_b], dim=-3)
return psf

def psf_spectrum(self, points, ks=51, **kwargs):
"""Should be migrated to psf_rgb."""
# Red
psf_r = []
for i, wvln in enumerate(WAVE_BOARD_BAND):
psf = self.psf(points=points, ks=ks, wvln=wvln, **kwargs)
psf_r.append(psf * RED_RESPONSE[i])
psf_r = torch.stack(psf_r, dim=0).sum(dim=0) / sum(RED_RESPONSE)

# Green
psf_g = []
for i, wvln in enumerate(WAVE_BOARD_BAND):
psf = self.psf(points=points, ks=ks, wvln=wvln, **kwargs)
psf_g.append(psf * GREEN_RESPONSE[i])
psf_g = torch.stack(psf_g, dim=0).sum(dim=0) / sum(GREEN_RESPONSE)

# Blue
psf_b = []
for i, wvln in enumerate(WAVE_BOARD_BAND):
psf = self.psf(points=points, ks=ks, wvln=wvln, **kwargs)
psf_b.append(psf * BLUE_RESPONSE[i])
psf_b = torch.stack(psf_b, dim=0).sum(dim=0) / sum(BLUE_RESPONSE)

# RGB
psf = torch.stack([psf_r, psf_g, psf_b], dim=0) # shape [3, ks, ks]
return psf

def psf_map(self, grid=21, ks=51, depth=-20000.0, **kwargs):
def draw_psf(self, depth=DEPTH, ks=101, save_name="./psf.png"):
"""Draw RGB on-axis PSF."""
psfs = []
for wvln in WAVE_RGB:
psf = self.psf(point=[0, 0, depth], ks=ks, wvln=wvln)
psfs.append(psf)

psfs = torch.stack(psfs, dim=0) # shape [3, ks, ks]
save_image(psfs.unsqueeze(0), save_name, normalize=True)

def point_source_grid(
self, depth, grid=9, normalized=True, quater=False, center=False
):
"""
Generate point source grid for PSF calculation.
"""
# ==> Use center of each patch
if grid == 1:
x, y = torch.tensor([[0.0]]), torch.tensor([[0.0]])
assert not quater, "Quater should be False when grid is 1."
else:
if center:
half_bin_size = 1 / 2 / (grid - 1)
x, y = torch.meshgrid(
torch.linspace(-1 + half_bin_size, 1 - half_bin_size, grid),
torch.linspace(1 - half_bin_size, -1 + half_bin_size, grid),
indexing="xy",
)
# ==> Use corner
else:
x, y = torch.meshgrid(
torch.linspace(-0.98, 0.98, grid),
torch.linspace(0.98, -0.98, grid),
indexing="xy",
)

z = torch.full((grid, grid), depth)
point_source = torch.stack([x, y, z], dim=-1)

# ==> Use quater of the sensor plane to save memory
if quater:
z = torch.full((grid, grid), depth)
point_source = torch.stack([x, y, z], dim=-1)
bound_i = grid // 2 if grid % 2 == 0 else grid // 2 + 1
bound_j = grid // 2
point_source = point_source[0:bound_i, bound_j:, :]

if not normalized:
raise Exception("Need to specify the scale.")
scale = self.calc_scale_pinhole(depth)
point_source[..., 0] *= scale * self.sensor_size[0] / 2
point_source[..., 1] *= scale * self.sensor_size[1] / 2

return point_source

def psf_map(self, grid=21, ks=51, depth=-20000.0, wvln=0.589, **kwargs):
"""Compute PSF map."""
pass
# raise NotImplementedError
points = self.point_source_grid(depth=depth, grid=grid, center=True)
points = points.reshape(-1, 3)

psfs = []
for i in range(points.shape[0]):
point = points[i, ...]
psf = self.psf(point=point, ks=ks, wvln=wvln)
psfs.append(psf)

psf_map = torch.stack(psfs).unsqueeze(1)
psf_map = make_grid(psf_map, nrow=grid, padding=0)[0, :, :]

return psf_map

def psf_map_rgb(self, grid=21, ks=51, depth=-20000.0, **kwargs):
"""Compute RGB PSF map."""
psfs = []
for wvln in WAVE_RGB:
psf_map = self.psf_map(grid=grid, ks=ks, depth=depth, wvln=wvln, **kwargs)
psfs.append(psf_map)
psf_map = torch.stack(psfs, dim=0) # shape [3, grid*ks, grid*ks]
return psf_map

def draw_psf_map(
self,
grid=8,
depth=DEPTH,
ks=101,
log_scale=False,
save_name="./psf_map.png",
):
"""Draw RGB PSF map of the DOE thin lens."""
# Calculate PSF map
psf_maps = []
for wvln in WAVE_RGB:
psf_map = self.psf_map(grid=grid, depth=depth, ks=ks, wvln=wvln)
psf_maps.append(psf_map)
psf_map = torch.stack(psf_maps, dim=0) # shape [3, grid*ks, grid*ks]

# Data processing for visualization
if log_scale:
psf_map = torch.log(psf_map + 1e-4)
psf_map = (psf_map - psf_map.min()) / (psf_map.max() - psf_map.min())

save_image(psf_map.unsqueeze(0), save_name, normalize=True)

# ===========================================
# Rendering-ralated functions
Expand Down
10 changes: 8 additions & 2 deletions deeplens/optics/materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ def __init__(self, name=None):
self.name = "vacuum" if name is None else name.lower()
self.load_dispersion()

def get_name(self):
if self.dispersion == "optimizable":
return f"{self.n.item():.4f}/{self.V.item():.2f}"
else:
return self.name

def load_dispersion(self):
"""Load material dispersion equation."""
if self.name in SELLMEIER_TABLE:
Expand Down Expand Up @@ -115,7 +121,7 @@ def match_material(self):

self.load_dispersion()

def get_optimizer_params(self, lr=1e-3):
def get_optimizer_params(self, lr=[1e-4, 1e-2]):
"""Optimize the material parameters (n, V)."""
self.n = torch.tensor(self.n).to(self.device)
self.V = torch.tensor(self.V).to(self.device)
Expand All @@ -124,7 +130,7 @@ def get_optimizer_params(self, lr=1e-3):
self.V.requires_grad = True
self.dispersion = "optimizable"

params = {"params": [self.A, self.B], "lr": lr}
params = [{"params": [self.n], "lr": lr[0]}, {"params": [self.V], "lr": lr[1]}]
return params


Expand Down
30 changes: 15 additions & 15 deletions deeplens/optics/surfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def surf_dict(self):
"r": float(f"{self.r:.6f}"),
"(d)": float(f"{self.d.item():.3f}"),
"is_square": self.is_square,
"mat2": self.mat2.name,
"mat2": self.mat2.get_name(),
}

return surf_dict
Expand Down Expand Up @@ -800,7 +800,7 @@ def get_optimizer_params(
f"params.append({{'params': [self.ai{2*i}], 'lr': lr[3] * decay**{i-1}}})"
)

if optim_mat and self.mat2.name != "air":
if optim_mat and self.mat2.get_name() != "air":
params += self.mat2.get_optimizer_params()

return params
Expand Down Expand Up @@ -836,7 +836,7 @@ def surf_dict(self):
"(d)": float(f"{self.d.item():.3f}"),
"k": float(f"{self.k.item():.6f}"),
"ai": [],
"mat2": self.mat2.name,
"mat2": self.mat2.get_name(),
}
for i in range(1, self.ai_degree + 1):
exec(f"surf_dict['(ai{2*i})'] = self.ai{2*i}.item()")
Expand All @@ -852,7 +852,7 @@ def zmx_str(self, surf_idx, d_next):
assert (
self.ai is not None or self.k != 0
), "Spheric surface is re-implemented in Spheric class."
if self.mat2.name == "air":
if self.mat2.get_name() == "air":
zmx_str = f"""SURF {surf_idx}
TYPE EVENASPH
CURV {self.c.item()}
Expand All @@ -870,7 +870,7 @@ def zmx_str(self, surf_idx, d_next):
TYPE EVENASPH
CURV {self.c.item()}
DISZ {d_next.item()}
GLAS {self.mat2.name.upper()} 0 0 {self.mat2.n} {self.mat2.V}
GLAS {self.mat2.get_name().upper()} 0 0 {self.mat2.n} {self.mat2.V}
DIAM {self.r * 2}
PARM 1 {self.ai2.item()}
PARM 2 {self.ai4.item()}
Expand Down Expand Up @@ -989,7 +989,7 @@ def get_optimizer_params(self, lr, optim_mat=False):
else:
raise Exception("Unsupported cubic degree!")

if optim_mat and self.mat2.name != "air":
if optim_mat and self.mat2.get_name() != "air":
params += self.mat2.get_optimizer_params()

return params
Expand Down Expand Up @@ -1428,7 +1428,7 @@ def surf_dict(self):
"param_model": self.param_model,
"f0": self.f0.item(),
"(d)": float(f"{self.d.item():.3f}"),
"mat2": self.mat2.name,
"mat2": self.mat2.get_name(),
}

elif self.param_model == "binary2":
Expand All @@ -1442,7 +1442,7 @@ def surf_dict(self):
"order6": self.order6.item(),
"order8": self.order8.item(),
"(d)": f"{float(self.d.item()):.3f}",
"mat2": self.mat2.name,
"mat2": self.mat2.get_name(),
}

elif self.param_model == "poly1d":
Expand All @@ -1458,7 +1458,7 @@ def surf_dict(self):
"order6": self.order6.item(),
"order7": self.order7.item(),
"(d)": float(f"{self.d.item():.3f}"),
"mat2": self.mat2.name,
"mat2": self.mat2.get_name(),
}

elif self.param_model == "grating":
Expand All @@ -1470,7 +1470,7 @@ def surf_dict(self):
"theta": self.theta.item(),
"alpha": self.alpha.item(),
"(d)": float(f"{self.d.item():.3f}"),
"mat2": self.mat2.name,
"mat2": self.mat2.get_name(),
}

return surf_dict
Expand Down Expand Up @@ -1575,7 +1575,7 @@ def surf_dict(self):
"l": self.l,
"(d)": float(f"{self.d.item():.3f}"),
"is_square": True,
"mat2": self.mat2.name,
"mat2": self.mat2.get_name(),
}

return surf_dict
Expand Down Expand Up @@ -1642,7 +1642,7 @@ def get_optimizer_params(self, lr=[0.001, 0.001], optim_mat=False):
params.append({"params": [self.c], "lr": lr[0]})
params.append({"params": [self.d], "lr": lr[1]})

if optim_mat and self.mat2.name != "air":
if optim_mat and self.mat2.get_name() != "air":
params += self.mat2.get_optimizer_params()

return params
Expand All @@ -1656,14 +1656,14 @@ def surf_dict(self):
"(c)": float(f"{self.c.item():.3f}"),
"roc": float(f"{roc:.3f}"),
"(d)": float(f"{self.d.item():.3f}"),
"mat2": self.mat2.name,
"mat2": self.mat2.get_name(),
}

return surf_dict

def zmx_str(self, surf_idx, d_next):
"""Return Zemax surface string."""
if self.mat2.name == "air":
if self.mat2.get_name() == "air":
zmx_str = f"""SURF {surf_idx}
TYPE STANDARD
CURV {self.c.item()}
Expand All @@ -1675,7 +1675,7 @@ def zmx_str(self, surf_idx, d_next):
TYPE STANDARD
CURV {self.c.item()}
DISZ {d_next.item()}
GLAS {self.mat2.name.upper()} 0 0 {self.mat2.n} {self.mat2.V}
GLAS {self.mat2.get_name().upper()} 0 0 {self.mat2.n} {self.mat2.V}
DIAM {self.r*2}
"""
return zmx_str
Expand Down
Loading

0 comments on commit 6762170

Please sign in to comment.