From 345c5f51c8cd28fb3fd6faf2d3c786af81153ccf Mon Sep 17 00:00:00 2001 From: Maciej Domaradzki Date: Sat, 4 Mar 2023 18:00:23 +0100 Subject: [PATCH 1/6] adding new metrics --- environment.yml | 2 ++ opt/opt.py | 24 +++++++++++++++++++++--- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/environment.yml b/environment.yml index a8f6790e..9b626767 100644 --- a/environment.yml +++ b/environment.yml @@ -22,6 +22,8 @@ dependencies: - moviepy - matplotlib - scipy>=1.6.0 + - lpips + - torchmetrics - pytorch=1.11.0 - torchvision - cudatoolkit diff --git a/opt/opt.py b/opt/opt.py index a5a287da..ec9abcb3 100644 --- a/opt/opt.py +++ b/opt/opt.py @@ -29,6 +29,9 @@ from tqdm import tqdm from typing import NamedTuple, Optional, Union +import lpips +from torchmetrics.functional import structural_similarity_index_measure + device = "cuda" if torch.cuda.is_available() else "cpu" parser = argparse.ArgumentParser() @@ -365,6 +368,7 @@ if args.enable_random: warn("Randomness is enabled for training (normal for LLFF & scenes with background)") +lpips_fn = lpips.LPIPS(net='alex').eval().to(device) epoch_id = -1 while True: dset.shuffle_rays() @@ -376,7 +380,7 @@ def eval_step(): # Put in a function to avoid memory leak print('Eval step') with torch.no_grad(): - stats_test = {'psnr' : 0.0, 'mse' : 0.0} + stats_test = {'psnr' : 0.0, 'ssim' : 0.0, 'lpips' : 0.0, 'mse' : 0.0} # Standard set N_IMGS_TO_EVAL = min(20 if epoch_id > 0 else 5, dset_test.n_images) @@ -431,8 +435,14 @@ def eval_step(): if math.isnan(psnr): print('NAN PSNR', i, img_id, mse_num) assert False + rgb_pred_test_perm = rgb_pred_test.unsqueeze(0).permute(0, 3, 1, 2) + rgb_gt_test_perm = rgb_gt_test.unsqueeze(0).permute(0, 3, 1, 2) + ssim = structural_similarity_index_measure(rgb_pred_test_perm, rgb_gt_test_perm) + lpips = lpips_fn(rgb_pred_test_perm, rgb_gt_test_perm) stats_test['mse'] += mse_num stats_test['psnr'] += psnr + stats_test['ssim'] += ssim + stats_test['lpips'] += lpips n_images_gen += 1 if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE or \ @@ -461,6 +471,8 @@ def eval_step(): stats_test['mse'] /= n_images_gen stats_test['psnr'] /= n_images_gen + stats_test['ssim'] /= n_images_gen + stats_test['lpips'] /= n_images_gen for stat_name in stats_test: summary_writer.add_scalar('test/' + stat_name, stats_test[stat_name], global_step=gstep_id_base) @@ -474,7 +486,7 @@ def eval_step(): def train_step(): print('Train step') pbar = tqdm(enumerate(range(0, epoch_size, args.batch_size)), total=batches_per_epoch) - stats = {"mse" : 0.0, "psnr" : 0.0, "invsqr_mse" : 0.0} + stats = {"mse" : 0.0, "psnr" : 0.0, 'ssim' : 0.0, 'lpips' : 0.0, "invsqr_mse" : 0.0} for iter_id, batch_begin in pbar: gstep_id = iter_id + gstep_id_base if args.lr_fg_begin_step > 0 and gstep_id == args.lr_fg_begin_step: @@ -507,13 +519,19 @@ def train_step(): # Stats mse_num : float = mse.detach().item() psnr = -10.0 * math.log10(mse_num) + rgb_pred_perm = rgb_pred.unsqueeze(0).permute(0, 3, 1, 2) + rgb_gt_perm = rgb_gt.unsqueeze(0).permute(0, 3, 1, 2) + ssim = structural_similarity_index_measure(rgb_pred_perm, rgb_gt_perm) + lpips = lpips_fn(rgb_pred_perm, rgb_gt_perm) stats['mse'] += mse_num stats['psnr'] += psnr + stats_test['ssim'] += ssim + stats_test['lpips'] += lpips stats['invsqr_mse'] += 1.0 / mse_num ** 2 if (iter_id + 1) % args.print_every == 0: # Print averaged stats - pbar.set_description(f'epoch {epoch_id} psnr={psnr:.2f}') + pbar.set_description(f'epoch {epoch_id} psnr={psnr:.2f} ssim={ssim:.2f} lpips={lpips:.2f}') for stat_name in stats: stat_val = stats[stat_name] / args.print_every summary_writer.add_scalar(stat_name, stat_val, global_step=gstep_id) From 99c069ed1f073c22971de1929617f5dffe1f744a Mon Sep 17 00:00:00 2001 From: Maciej Domaradzki Date: Sat, 4 Mar 2023 18:03:15 +0100 Subject: [PATCH 2/6] adding new metrics --- opt/opt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/opt/opt.py b/opt/opt.py index ec9abcb3..000f10e7 100644 --- a/opt/opt.py +++ b/opt/opt.py @@ -429,14 +429,14 @@ def eval_step(): depth_img, global_step=gstep_id_base, dataformats='HWC') + rgb_pred_test_perm = rgb_pred_test.unsqueeze(0).permute(0, 3, 1, 2) + rgb_gt_test_perm = rgb_gt_test.unsqueeze(0).permute(0, 3, 1, 2) rgb_pred_test = rgb_gt_test = None mse_num : float = all_mses.mean().item() psnr = -10.0 * math.log10(mse_num) if math.isnan(psnr): print('NAN PSNR', i, img_id, mse_num) assert False - rgb_pred_test_perm = rgb_pred_test.unsqueeze(0).permute(0, 3, 1, 2) - rgb_gt_test_perm = rgb_gt_test.unsqueeze(0).permute(0, 3, 1, 2) ssim = structural_similarity_index_measure(rgb_pred_test_perm, rgb_gt_test_perm) lpips = lpips_fn(rgb_pred_test_perm, rgb_gt_test_perm) stats_test['mse'] += mse_num @@ -517,10 +517,10 @@ def train_step(): mse = F.mse_loss(rgb_gt, rgb_pred) # Stats - mse_num : float = mse.detach().item() - psnr = -10.0 * math.log10(mse_num) rgb_pred_perm = rgb_pred.unsqueeze(0).permute(0, 3, 1, 2) rgb_gt_perm = rgb_gt.unsqueeze(0).permute(0, 3, 1, 2) + mse_num : float = mse.detach().item() + psnr = -10.0 * math.log10(mse_num) ssim = structural_similarity_index_measure(rgb_pred_perm, rgb_gt_perm) lpips = lpips_fn(rgb_pred_perm, rgb_gt_perm) stats['mse'] += mse_num From dfeeef36d22c2aad76c5d3282334a9c8dfd57deb Mon Sep 17 00:00:00 2001 From: Maciej Domaradzki Date: Sat, 4 Mar 2023 18:12:45 +0100 Subject: [PATCH 3/6] adding new metrics --- opt/opt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/opt/opt.py b/opt/opt.py index 000f10e7..f9133f75 100644 --- a/opt/opt.py +++ b/opt/opt.py @@ -438,7 +438,7 @@ def eval_step(): print('NAN PSNR', i, img_id, mse_num) assert False ssim = structural_similarity_index_measure(rgb_pred_test_perm, rgb_gt_test_perm) - lpips = lpips_fn(rgb_pred_test_perm, rgb_gt_test_perm) + lpips = lpips_fn(rgb_pred_test_perm, rgb_gt_test_perm, normalize=True).item() stats_test['mse'] += mse_num stats_test['psnr'] += psnr stats_test['ssim'] += ssim @@ -517,12 +517,12 @@ def train_step(): mse = F.mse_loss(rgb_gt, rgb_pred) # Stats - rgb_pred_perm = rgb_pred.unsqueeze(0).permute(0, 3, 1, 2) - rgb_gt_perm = rgb_gt.unsqueeze(0).permute(0, 3, 1, 2) + rgb_pred_perm = rgb_pred.permute(0, 3, 1, 2) + rgb_gt_perm = rgb_gt.permute(0, 3, 1, 2) mse_num : float = mse.detach().item() psnr = -10.0 * math.log10(mse_num) ssim = structural_similarity_index_measure(rgb_pred_perm, rgb_gt_perm) - lpips = lpips_fn(rgb_pred_perm, rgb_gt_perm) + lpips = lpips_fn(rgb_pred_perm, rgb_gt_perm, normalize=True).item() stats['mse'] += mse_num stats['psnr'] += psnr stats_test['ssim'] += ssim From 9bc72feca50054960449de1ffa2db86082bc6764 Mon Sep 17 00:00:00 2001 From: Maciej Domaradzki Date: Sat, 4 Mar 2023 18:15:09 +0100 Subject: [PATCH 4/6] adding new metrics --- opt/opt.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/opt/opt.py b/opt/opt.py index f9133f75..fa574e69 100644 --- a/opt/opt.py +++ b/opt/opt.py @@ -437,7 +437,7 @@ def eval_step(): if math.isnan(psnr): print('NAN PSNR', i, img_id, mse_num) assert False - ssim = structural_similarity_index_measure(rgb_pred_test_perm, rgb_gt_test_perm) + ssim = structural_similarity_index_measure(rgb_pred_test_perm, rgb_gt_test_perm).item() lpips = lpips_fn(rgb_pred_test_perm, rgb_gt_test_perm, normalize=True).item() stats_test['mse'] += mse_num stats_test['psnr'] += psnr @@ -517,11 +517,12 @@ def train_step(): mse = F.mse_loss(rgb_gt, rgb_pred) # Stats + print(rgb_pred.shape, rgb_gt.shape) rgb_pred_perm = rgb_pred.permute(0, 3, 1, 2) rgb_gt_perm = rgb_gt.permute(0, 3, 1, 2) mse_num : float = mse.detach().item() psnr = -10.0 * math.log10(mse_num) - ssim = structural_similarity_index_measure(rgb_pred_perm, rgb_gt_perm) + ssim = structural_similarity_index_measure(rgb_pred_perm, rgb_gt_perm).item() lpips = lpips_fn(rgb_pred_perm, rgb_gt_perm, normalize=True).item() stats['mse'] += mse_num stats['psnr'] += psnr From f86a4786511e25329c1f4021b0fa3705c46f2e86 Mon Sep 17 00:00:00 2001 From: Maciej Domaradzki Date: Sat, 4 Mar 2023 18:22:59 +0100 Subject: [PATCH 5/6] adding new metrics --- opt/opt.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/opt/opt.py b/opt/opt.py index fa574e69..77be16e5 100644 --- a/opt/opt.py +++ b/opt/opt.py @@ -517,22 +517,15 @@ def train_step(): mse = F.mse_loss(rgb_gt, rgb_pred) # Stats - print(rgb_pred.shape, rgb_gt.shape) - rgb_pred_perm = rgb_pred.permute(0, 3, 1, 2) - rgb_gt_perm = rgb_gt.permute(0, 3, 1, 2) mse_num : float = mse.detach().item() psnr = -10.0 * math.log10(mse_num) - ssim = structural_similarity_index_measure(rgb_pred_perm, rgb_gt_perm).item() - lpips = lpips_fn(rgb_pred_perm, rgb_gt_perm, normalize=True).item() stats['mse'] += mse_num stats['psnr'] += psnr - stats_test['ssim'] += ssim - stats_test['lpips'] += lpips stats['invsqr_mse'] += 1.0 / mse_num ** 2 if (iter_id + 1) % args.print_every == 0: # Print averaged stats - pbar.set_description(f'epoch {epoch_id} psnr={psnr:.2f} ssim={ssim:.2f} lpips={lpips:.2f}') + pbar.set_description(f'epoch {epoch_id} psnr={psnr:.2f}') for stat_name in stats: stat_val = stats[stat_name] / args.print_every summary_writer.add_scalar(stat_name, stat_val, global_step=gstep_id) From 5daeb279d2e72730e5c6969cebe1895cb768437e Mon Sep 17 00:00:00 2001 From: Maciej Domaradzki Date: Sat, 4 Mar 2023 22:56:34 +0100 Subject: [PATCH 6/6] adding new metrics --- opt/opt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opt/opt.py b/opt/opt.py index 77be16e5..73615d33 100644 --- a/opt/opt.py +++ b/opt/opt.py @@ -486,7 +486,7 @@ def eval_step(): def train_step(): print('Train step') pbar = tqdm(enumerate(range(0, epoch_size, args.batch_size)), total=batches_per_epoch) - stats = {"mse" : 0.0, "psnr" : 0.0, 'ssim' : 0.0, 'lpips' : 0.0, "invsqr_mse" : 0.0} + stats = {"mse" : 0.0, "psnr" : 0.0, "invsqr_mse" : 0.0} for iter_id, batch_begin in pbar: gstep_id = iter_id + gstep_id_base if args.lr_fg_begin_step > 0 and gstep_id == args.lr_fg_begin_step: