diff --git a/1_end2end_5lines.py b/1_end2end_5lines.py index dfcf477..f40f356 100644 --- a/1_end2end_5lines.py +++ b/1_end2end_5lines.py @@ -33,12 +33,6 @@ def config(): with open('configs/end2end_5lines.yml') as f: args = yaml.load(f, Loader=yaml.FullLoader) - # ==> Device - num_gpus = torch.cuda.device_count() - args['num_gpus'] = num_gpus - device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu") - args['device'] = device - # ==> Result folder characters = string.ascii_letters + string.digits random_string = ''.join(random.choice(characters) for i in range(4)) @@ -46,18 +40,27 @@ def config(): args['result_dir'] = result_dir os.makedirs(result_dir, exist_ok=True) print(f'Result folder: {result_dir}') - - # ==> Logger - set_logger(result_dir) - logging.info(args) # ==> Random seed set_seed(args['train']['seed']) + # ==> Logger + set_logger(result_dir) + # Log to wandb if not args['DEBUG']: - # wandb init pass - + + # ==> Device + num_gpus = torch.cuda.device_count() + args['num_gpus'] = num_gpus + device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu") + logging.info(f'Using {num_gpus} {torch.cuda.get_device_name(0)} GPU(s)') + args['device'] = device + + # ==> Save config + with open(f'{result_dir}/config.yml', 'w') as f: + yaml.dump(args, f) + return args diff --git a/deeplens/optics.py b/deeplens/optics.py index a2e6910..a600ebc 100644 --- a/deeplens/optics.py +++ b/deeplens/optics.py @@ -124,6 +124,9 @@ def post_computation(self): avg_pupilz, avg_pupilx = self.entrance_pupil() self.fnum = self.foclen / avg_pupilx / 2 + if self.r_last < 8.0: + self.is_cellphone = True + def find_aperture(self): """ Find aperture by surfaces previous and next materials. @@ -820,7 +823,8 @@ def render_compute_image(self, img, depth, scale, ray, point_pixel=True, train=T irr_img += img[...,idx_i, idx_j+1] * w_i * (1-w_j) irr_img += img[...,idx_i+1, idx_j+1] * (1-w_i) * (1-w_j) - I = (torch.sum(irr_img * ray.ra, -3) + 1e-9) / (torch.sum(ray.ra, -3) + 1e-6) + I = (torch.sum(irr_img * ray.ra, -3) + 1e-9) / (torch.sum(ray.ra, -3) + 1e-6) # w/ vignetting correction + # I = (torch.sum(irr_img * ray.ra, -3) + 1e-9) / ray.ra.shape[-3] # w/o vignetting correction # ====> Add sensor noise if noise > 0: @@ -1469,7 +1473,7 @@ def set_target_fov_fnum(self, hfov, fnum, imgh=None): self.foclen = self.calc_efl() aper_r = self.foclen / fnum / 2 - self.surfaces[self.aper_idx].r = aper_r + self.surfaces[self.aper_idx].r = float(aper_r) # --------------------------- @@ -2216,11 +2220,10 @@ def loss_ray_angle(self, target=0.7, depth=DEPTH): def loss_reg(self): """ An empirical regularization loss for lens design. """ - # For spherical lens design - loss_reg = 0.1 * self.loss_infocus() + self.loss_self_intersec(dist_bound=0.5, thickness_bound=0.5) - - # For cellphone lens design, use 0.01 * loss_reg - # loss_reg = 0.1 * self.loss_infocus() + self.loss_ray_angle() + (self.loss_self_intersec() + self.loss_last_surf()) #+ self.loss_surface() + if self.is_cellphone: + loss_reg = 0.1 * self.loss_infocus() + self.loss_ray_angle() + (self.loss_self_intersec() + self.loss_last_surf()) + else: + loss_reg = 0.1 * self.loss_infocus() + self.loss_self_intersec(dist_bound=0.5, thickness_bound=0.5) return loss_reg diff --git a/deeplens/surfaces.py b/deeplens/surfaces.py index b98b41f..187a6f9 100644 --- a/deeplens/surfaces.py +++ b/deeplens/surfaces.py @@ -275,13 +275,22 @@ def surface_sample(self, N=1000): o2 = torch.stack((x2,y2,z2), 1).to(self.device) return o2 + def surface(self, x, y): + """ Calculate z coordinate of the surface at (x, y) with offset. + + This function is used in lens setup plotting. + """ + x = x if torch.is_tensor(x) else torch.tensor(x).to(self.device) + y = y if torch.is_tensor(y) else torch.tensor(y).to(self.device) + return self.sag(x, y) + def surface_with_offset(self, x, y): """ Calculate z coordinate of the surface at (x, y) with offset. This function is used in lens setup plotting. """ - x = torch.tensor(x).to(self.device) if type(x) is float else x - y = torch.tensor(y).to(self.device) if type(y) is float else y + x = x if torch.is_tensor(x) else torch.tensor(x).to(self.device) + y = y if torch.is_tensor(y) else torch.tensor(y).to(self.device) return self.sag(x, y) + self.d def max_height(self): diff --git a/setup.py b/setup.py index 2c8460a..3f196fa 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ 'transformers', 'lpips', 'einops', + 'timm', ], license='Creative Commons Attribution-NonCommercial 4.0 International License', classifiers=[