def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory and lazily yield the files of interest.

    Args:
        dir_path (str): Path of the directory to scan.
        suffix (str | tuple(str), optional): Keep only files whose returned
            path ends with this suffix (or with any suffix of the tuple).
            Default: None, i.e. keep every file.
        recursive (bool, optional): If True, descend into sub-directories.
            Default: False.
        full_path (bool, optional): If True, the yielded paths include
            ``dir_path``; otherwise they are relative to it. Default: False.

    Returns:
        A generator over the matching file paths. Files whose name starts
        with a dot are never yielded.

    Raises:
        TypeError: If ``suffix`` is neither None, a string nor a tuple.
    """
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(current_dir):
        # ``with`` closes the directory handle even when the consumer stops
        # iterating early.
        with os.scandir(current_dir) as entries:
            for entry in entries:
                if entry.is_file():
                    if entry.name.startswith('.'):
                        # Hidden files are skipped. They used to fall through
                        # to the recursion branch, where os.scandir() on a
                        # regular file raised NotADirectoryError.
                        continue

                    if full_path:
                        return_path = entry.path
                    else:
                        return_path = os.path.relpath(entry.path, root)

                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                elif recursive and entry.is_dir():
                    # Only real directories are descended into (hidden ones
                    # included, as before).
                    yield from _scandir(entry.path)

    return _scandir(dir_path)
def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired sparse-depth / ground-truth / RGB paths from folders.

    The ground-truth folder drives the pairing: for every file found there,
    the input file is expected under the same (templated) name and extension
    in the input folder, and the RGB file under the same (templated) name
    with a ``.png`` extension in the RGB folder.

    Args:
        folders (list[str]): Folder paths, in the order
            [input_folder, gt_folder, rgb_folder].
        keys (list[str]): Keys identifying the folders, in the same order as
            ``folders``, e.g. ['lq', 'gt', 'rgb'].
        filename_tmpl (str): ``str.format`` template applied to each
            ground-truth base name (without extension) to build the input and
            RGB file names.

    Returns:
        list[dict]: One dict per ground-truth file, mapping
        ``'<input_key>_path'``, ``'<gt_key>_path'`` and ``'<rgb_key>_path'``
        to the corresponding file path.

    Raises:
        AssertionError: If ``folders`` or ``keys`` does not have exactly
            three entries, or the three folders hold a different number of
            files.
    """
    assert len(folders) == 3, (
        'The len of folders should be 3 with [input_folder, gt_folder, rgb_folder]. '
        f'But got {len(folders)}')
    assert len(keys) == 3, (
        'The len of keys should be 3 with [input_key, gt_key, rgb_key]. '
        f'But got {len(keys)}')
    input_folder, gt_folder, rgb_folder = folders
    input_key, gt_key, rgb_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    rgb_paths = list(scandir(rgb_folder))
    assert len(input_paths) == len(gt_paths), (
        f'{input_key} and {gt_key} datasets have different number of images: '
        f'{len(input_paths)}, {len(gt_paths)}.')
    assert len(input_paths) == len(rgb_paths), (
        f'{input_key} and {rgb_key} datasets have different number of images: '
        f'{len(input_paths)}, {len(rgb_paths)}.')

    rgb_ext = '.png'
    paths = []

    for gt_rel_path in gt_paths:
        basename, ext = os.path.splitext(os.path.basename(gt_rel_path))
        stem = filename_tmpl.format(basename)

        # The RGB path must be built from the RGB file *name*. The previous
        # code joined the folder with the not-yet-assigned ``rgb_path``
        # variable, which raised UnboundLocalError on the first iteration.
        paths.append({
            f'{input_key}_path': os.path.join(input_folder, f'{stem}{ext}'),
            f'{gt_key}_path': os.path.join(gt_folder, gt_rel_path),
            f'{rgb_key}_path': os.path.join(rgb_folder, f'{stem}{rgb_ext}'),
        })
    return paths
class nyu_modelA(nn.Module):
    """Depth generator: transformer encoder on RGB fused with conv skip paths.

    The RGB image is split into patches and embedded, a learnable positional
    embedding is added, and the tokens go through a stack of transformer
    ``Block`` modules. The normalized tokens are folded back into an
    image-sized feature map, which is summed with shallow convolutional
    features of the RGB image and of the ``lidar`` input, refined by
    residual-in-residual dense blocks and projected to ``out_chans`` channels.
    There is no cls token and no classification head.

    Parameters
    ----------
    filters : int
        Number of feature channels used by all convolutional stages.

    num_res_blocks : int
        Number of ``ResidualInResidualDenseBlock`` modules.

    img_size : int Tuple
        Height and width of the input (it is not a square).

    patch_size : int
        Both height and the width of the patch (it is a square).

    in_chans : int
        Number of channels of the RGB input.

    out_chans : int
        Number of output channels. Also used as the number of ``lidar``
        input channels unless ``lidar_chans`` is given.

    n_classes : int
        Unused; kept so existing call sites keep working.

    embed_dim : int
        Dimensionality of the token/patch embeddings. Must be a multiple of
        ``patch_size ** 2`` so the tokens can be folded back into an image.

    depth : int
        Number of transformer blocks.

    n_heads : int
        Number of attention heads.

    mlp_ratio : float
        Determines the hidden dimension of the `MLP` module.

    qkv_bias : bool
        If True then we include bias to the query, key and value projections.

    p, attn_p : float
        Dropout probability.

    rgb_chans : int, optional
        Alias for ``in_chans`` using the keyword name of ``nyu_modelB``;
        takes precedence over ``in_chans`` when given.

    lidar_chans : int, optional
        Number of ``lidar`` input channels, using the keyword name of
        ``nyu_modelB``. Defaults to ``out_chans``.

    Attributes
    ----------
    patch_embed : PatchEmbed
        Instance of `PatchEmbed` layer.

    pos_embed : nn.Parameter
        Positional embedding of all the patches.
        It has `n_patches * embed_dim` elements.

    pos_drop : nn.Dropout
        Dropout layer.

    blocks : nn.ModuleList
        List of `Block` modules.

    norm : nn.LayerNorm
        Layer normalization.

    token_fold : nn.Fold
        Re-assembles the token sequence into an image-sized feature map.
    """
    def __init__(
            self,
            filters=64,
            num_res_blocks=1,
            img_size=(240,304),
            patch_size=16,
            in_chans=3,
            out_chans = 1,
            n_classes=1000,
            embed_dim=768,
            depth=12,
            n_heads=12,
            mlp_ratio=4.,
            qkv_bias=True,
            p=0.,
            attn_p=0.,
            rgb_chans=None,
            lidar_chans=None,
    ):
        super().__init__()

        # The training script builds every generator from one shared config
        # that uses the keywords ``rgb_chans`` / ``lidar_chans``. Accept them
        # here so this model can be constructed from that config as well.
        if rgb_chans is not None:
            in_chans = rgb_chans
        if lidar_chans is None:
            lidar_chans = out_chans

        # nn.Fold turns each token of size embed_dim into a patch with
        # embed_dim / patch_size**2 channels (3 for the default 768 / 16**2).
        fold_chans, leftover = divmod(embed_dim, patch_size ** 2)
        if leftover:
            raise ValueError(
                'embed_dim must be a multiple of patch_size ** 2, '
                f'but got embed_dim={embed_dim} and patch_size={patch_size}')

        self.patch_embed = PatchEmbed(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
        )

        self.pos_embed = nn.Parameter(
                torch.zeros(1, self.patch_embed.n_patches, embed_dim)
        )

        self.pos_drop = nn.Dropout(p=p)

        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    n_heads=n_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    p=p,
                    attn_p=attn_p,
                )
                for _ in range(depth)
            ]
        )

        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

        self.token_fold = nn.Fold(output_size = img_size, kernel_size = patch_size, stride = patch_size)

        self.conv_skip_lidar = nn.Sequential(
            nn.Conv2d(lidar_chans, filters, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
        )

        self.conv_skip_rgb = nn.Sequential(
            nn.Conv2d(in_chans, filters, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
        )

        self.conv_fold = nn.Sequential(
            nn.Conv2d(fold_chans, filters, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
        )

        self.RRDB = nn.Sequential(*[ResidualInResidualDenseBlock(filters) for _ in range(num_res_blocks)])

        # Final output block
        self.conv = nn.Sequential(
            nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(filters, out_chans, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, rgb, lidar):
        """Run the forward pass.

        Parameters
        ----------
        rgb : torch.Tensor
            Shape `(n_samples, in_chans, img_size[0], img_size[1])`.

        lidar : torch.Tensor
            Shape `(n_samples, lidar_chans, img_size[0], img_size[1])`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, out_chans, img_size[0], img_size[1])`.
        """
        x = self.patch_embed(rgb)

        x = x + self.pos_embed  # (n_samples, n_patches, embed_dim)
        x = self.pos_drop(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        # nn.Fold expects (n_samples, channels * patch_area, n_patches)
        x = self.token_fold(x.transpose(1,2))

        x = self.conv_fold(x) + self.conv_skip_lidar(lidar) + self.conv_skip_rgb(rgb)

        x = self.RRDB(x)

        x = self.conv(x)

        return x
+ It has `embed_dim` elements. + + pos_emb : nn.Parameter + Positional embedding of the cls token + all the patches. + It has `(n_patches + 1) * embed_dim` elements. + + pos_drop : nn.Dropout + Dropout layer. + + blocks : nn.ModuleList + List of `Block` modules. + + norm : nn.LayerNorm + Layer normalization. + """ + def __init__( + self, + filters=64, + num_res_blocks=1, + img_size=(240,304), + patch_size=16, + rgb_chans=3, + lidar_chans=1, + out_chans = 1, + n_classes=1000, + embed_dim=768, + depth=12, + n_heads=12, + mlp_ratio=4., + qkv_bias=True, + p=0., + attn_p=0., + ): + super().__init__() + + self.patch_embed_rgb = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=rgb_chans, + embed_dim=embed_dim, + ) + + self.patch_embed_lidar = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=lidar_chans, + embed_dim=embed_dim, + ) + + self.pos_embed = nn.Parameter( + torch.zeros(1, self.patch_embed_lidar.n_patches + self.patch_embed_rgb.n_patches, embed_dim) + ) + + self.pos_drop = nn.Dropout(p=p) + + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + n_heads=n_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + p=p, + attn_p=attn_p, + ) + for _ in range(depth) + ] + ) + + self.norm = nn.LayerNorm(embed_dim, eps=1e-6) + # self.head = nn.Linear(embed_dim, n_classes) + + self.token_fold = nn.Fold(output_size = img_size, kernel_size = patch_size, stride = patch_size) + + self.conv_skip_lidar = nn.Sequential( + nn.Conv2d(lidar_chans, filters, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(), + ) + + self.conv_fusion_lidar = nn.Sequential( + nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(), + ) + + self.conv_skip_rgb = nn.Sequential( + nn.Conv2d(rgb_chans, filters, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(), + ) + + self.conv_fusion_rgb = nn.Sequential( + nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(), + ) + + self.conv_fold_rgb = nn.Sequential( 
def imgrad(img):
    """Return the Sobel responses of a batch of images.

    The input is first averaged over its channel dimension, then filtered
    with the two fixed 3x3 Sobel kernels using zero padding of 1, so the
    spatial size is preserved. The kernels are created on the device and in
    the dtype of ``img``, so the function works on CPU as well as on any GPU.

    Parameters
    ----------
    img : torch.Tensor
        Shape `(n_samples, channels, height, width)`, floating point.

    Returns
    -------
    grad_y : torch.Tensor
        Response to the kernel `[[1,2,1],[0,0,0],[-1,-2,-1]]`,
        shape `(n_samples, 1, height, width)`.

    grad_x : torch.Tensor
        Response to the kernel `[[1,0,-1],[2,0,-2],[1,0,-1]]`,
        shape `(n_samples, 1, height, width)`.
    """
    img = torch.mean(img, 1, True)

    # Constant filters: plain tensors, not nn.Parameter, so nothing trainable
    # is allocated on every call.
    fx = torch.tensor(
        [[1, 0, -1], [2, 0, -2], [1, 0, -1]],
        dtype=img.dtype, device=img.device,
    ).view(1, 1, 3, 3)
    fy = torch.tensor(
        [[1, 2, 1], [0, 0, 0], [-1, -2, -1]],
        dtype=img.dtype, device=img.device,
    ).view(1, 1, 3, 3)

    grad_x = F.conv2d(img, fx, stride=1, padding=1)
    grad_y = F.conv2d(img, fy, stride=1, padding=1)

    return grad_y, grad_x
class PatchEmbed(nn.Module):
    """Split image into patches and then embed them.

    Parameters
    ----------
    img_size : int Tuple or int
        Height and width of the image. A single int is taken as the side
        of a square image.

    patch_size : int
        Size of the patch (it is a square).

    in_chans : int
        Number of input channels.

    embed_dim : int
        The embedding dimension.

    Attributes
    ----------
    n_patches : int
        Number of patches inside of our image, i.e. the number of positions
        the strided convolution produces: `(H // patch_size) * (W // patch_size)`.

    proj : nn.Conv2d
        Convolutional layer that does both the splitting into patches
        and their embedding.
    """
    def __init__(self, img_size, patch_size, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size

        if isinstance(img_size, int):
            height = width = img_size
        else:
            height, width = img_size[0], img_size[1]

        # Count patches per axis with integer arithmetic. Dividing the total
        # area over-counts when a side is not a multiple of patch_size,
        # because the convolution below drops the incomplete border patches.
        self.n_patches = (height // patch_size) * (width // patch_size)

        self.proj = nn.Conv2d(
                in_chans,
                embed_dim,
                kernel_size=patch_size,
                stride=patch_size,
        )

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, in_chans, height, width)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches, embed_dim)`.
        """
        x = self.proj(
                x
            )  # (n_samples, embed_dim, height // patch_size, width // patch_size)
        x = x.flatten(2)  # (n_samples, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (n_samples, n_patches, embed_dim)

        return x
+ + qkv_bias : bool + If True then we include bias to the query, key and value projections. + + attn_p : float + Dropout probability applied to the query, key and value tensors. + + proj_p : float + Dropout probability applied to the output tensor. + + + Attributes + ---------- + scale : float + Normalizing consant for the dot product. + + qkv : nn.Linear + Linear projection for the query, key and value. + + proj : nn.Linear + Linear mapping that takes in the concatenated output of all attention + heads and maps it into a new space. + + attn_drop, proj_drop : nn.Dropout + Dropout layers. + """ + def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.): + super().__init__() + self.n_heads = n_heads + self.dim = dim + self.head_dim = dim // n_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_p) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_p) + + def forward(self, x): + """Run forward pass. + + Parameters + ---------- + x : torch.Tensor + Shape `(n_samples, n_patches + 1, dim)`. + + Returns + ------- + torch.Tensor + Shape `(n_samples, n_patches + 1, dim)`. 
class MLP(nn.Module):
    """Two-layer feed-forward network with a GELU in between.

    Parameters
    ----------
    in_features : int
        Size of the last dimension of the input.

    hidden_features : int
        Width of the hidden layer.

    out_features : int
        Size of the last dimension of the output.

    p : float
        Dropout probability, shared by both dropout applications.

    Attributes
    ----------
    fc1 : nn.Linear
        Projection from `in_features` to `hidden_features`.

    act : nn.GELU
        Non-linearity applied after `fc1`.

    fc2 : nn.Linear
        Projection from `hidden_features` to `out_features`.

    drop : nn.Dropout
        Dropout module, applied after the activation and after `fc2`.
    """
    def __init__(self, in_features, hidden_features, out_features, p=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        """Apply `fc1 -> act -> drop -> fc2 -> drop` to the last dimension.

        Parameters
        ----------
        x : torch.Tensor
            Shape `(..., in_features)`.

        Returns
        -------
        torch.Tensor
            Shape `(..., out_features)`.
        """
        hidden = self.drop(self.act(self.fc1(x)))  # (..., hidden_features)
        return self.drop(self.fc2(hidden))  # (..., out_features)
class Discriminator(nn.Module):
    """Convolutional discriminator that emits a low-resolution logit map.

    The network is four stages of two 3x3 convolutions each (the second one
    with stride 2), followed by a final 3x3 convolution down to one channel.
    No sigmoid is applied to the output.

    Parameters
    ----------
    input_shape : Tuple
        `(channels, height, width)` of one input sample.

    Attributes
    ----------
    input_shape : Tuple
        The `input_shape` argument as given.

    output_shape : Tuple
        `(1, out_height, out_width)` of the map produced for one sample of
        `input_shape`, e.g. `(1, 15, 19)` for a `(1, 240, 304)` input.

    model : nn.Sequential
        The convolutional stack.
    """
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        stage_filters = [64, 128, 256, 512]

        # Every stage contains exactly one stride-2 convolution with
        # kernel 3 / padding 1, which maps a side of length n to ceil(n / 2).
        out_height, out_width = in_height, in_width
        for _ in stage_filters:
            out_height = (out_height + 1) // 2
            out_width = (out_width + 1) // 2
        self.output_shape = (1, out_height, out_width)

        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate(stage_filters):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters

        layers.append(nn.Conv2d(out_filters, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the logit map for a batch of images.

        Parameters
        ----------
        img : torch.Tensor
            Shape `(n_samples, channels, height, width)`.

        Returns
        -------
        torch.Tensor
            Shape `(n_samples, 1, ceil(height / 16), ceil(width / 16))`.
        """
        return self.model(img)
def getOpt(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv (list[str], optional): Argument strings to parse. Default: None,
            which makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed options. ``name`` falls back to the
        value of ``--model`` when no experiment name is given, because the
        training script joins it into its log directory paths.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume_epoch", type=int, default=0, help="epoch to start training from")
    parser.add_argument("--n_epochs", type=int, default=10, help="number of epochs of training")
    parser.add_argument("--dataset", type=str, default="nyu_v2", help="name of the dataset (shapeNet or nyu_v2)")
    parser.add_argument("--model", type=str, default="nyu_modelA", required = True, help="name of the model (nyu_modelA | nyu_modelB)")
    # NOTE(review): the default below is a machine-specific path; callers
    # elsewhere are expected to pass --dataset_path explicitly.
    parser.add_argument("--dataset_path", type=str, default="/home/mdl/mzk591/dataset/data.nyuv2/disk3/", help="path to the dataset")
    parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
    parser.add_argument("--save_size", type=int, default=8, help="batch size for saved outputs")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.9, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
    parser.add_argument("--decay_epoch", type=int, default=15, help="epoch from which to start lr decay")
    parser.add_argument("--lr_gap", type=int, default=4, help="number of epochs between successive lr decay steps after decay_epoch")
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--hr_width", type=int, default=304, help="dense depth width")
    parser.add_argument("--channels", type=int, default=1, help="depth image has only 1 channel")
    parser.add_argument("--sample_interval", type=int, default=20, help="interval between saving image samples")
    parser.add_argument("--warmup_batches", type=int, default=15, help="number of batches with pixel-wise loss only")
    parser.add_argument("--lambda_adv", type=float, default=5e-3, help="adversarial loss weight")
    parser.add_argument("--lambda_pixel", type=float, default=1e-2, help="pixel-wise loss weight")
    parser.add_argument("--gpus", metavar='DEV_ID', default=None,
                        help='Comma-separated list of GPU device IDs to be used (default is to use all available devices)')
    parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name')
    parser.add_argument('--meta_info_file', '-m', metavar='DIR', default="meta_info.txt", help='Meta file name')
    parser.add_argument("--checkpoint_model_path", type=str, required=False, help="Path to checkpoint model")
    parser.add_argument("--pretrained_model_path", '-p', type=str, required=False, help="Path to pretrained model")

    opt = parser.parse_args(argv)

    # The training script builds its log paths with
    # os.path.join(LOGDIR, opt.name, ...), which raises TypeError for None.
    if opt.name is None:
        opt.name = opt.model

    return opt
getOpt() + + # create the logdir if it does not exists + os.makedirs(LOGDIR, exist_ok=True) + + # create addition log directories + val_image_save_path = os.path.join(LOGDIR,opt.name,"val_images") + saved_model_path = os.path.join(LOGDIR,opt.name,"saved_models") + log_file_name = os.path.join(LOGDIR,opt.name,'%s.log'%opt.name) + tensorboard_save_path = os.path.join(LOGDIR,opt.name) + + # os.makedirs(train_image_save_path, exist_ok=True) + os.makedirs(val_image_save_path, exist_ok=True) + os.makedirs(saved_model_path, exist_ok=True) + + # Create a logger + logger = createLogger(log_file_name) + + # print(opt) + logger.info(opt) + + # initiate tensorboard logger + writer = SummaryWriter(log_dir=tensorboard_save_path) + + if opt.gpus is not None: + try: + opt.gpus = [int(s) for s in opt.gpus.split(',')] + except ValueError: + logger.error('ERROR: Argument --gpus must be a comma-separated list of integers only') + exit(1) + available_gpus = torch.cuda.device_count() + for dev_id in opt.gpus: + if dev_id >= available_gpus: + logger.error('ERROR: GPU device ID {0} requested, but only {1} devices available' + .format(dev_id, available_gpus)) + exit(1) + # Set default device in case the first one on the list != 0 + torch.cuda.set_device(opt.gpus[0]) + + if 'shapeNet' in opt.model: + hr_shape = (192, 256) + elif "nyu" in opt.model: + hr_shape = (240, 304) + + model_config = { + "img_size": hr_shape, + "rgb_chans": 3, + "lidar_chans": 1, + "patch_size": 16, + "embed_dim": 768, + "depth": 12, + "n_heads": 12, + "qkv_bias": True, + "mlp_ratio": 4, + } + + # Initialize generator and discriminator + try: + generator = eval(opt.model)(**model_config) + except: + print("Please select model from: nyu_modelA | nyu_modelB") + quit() + + if opt.resume_epoch == 0 and opt.pretrained_model_path: + generator.load_state_dict(torch.load(opt.pretrained_model_path)) + generator = nn.DataParallel(generator, device_ids = opt.gpus) + generator.cuda() + + discriminator = 
Discriminator(input_shape=(opt.channels, *hr_shape)) + discriminator = nn.DataParallel(discriminator, device_ids = opt.gpus) + discriminator.cuda() + + # Losses + criterion_GAN = torch.nn.BCEWithLogitsLoss().cuda() + criterion_content = NormalLoss().cuda() + criterion_pixel = torch.nn.L1Loss().cuda() + + if opt.resume_epoch != 0: + # Load pretrained models + saved_generator_chkpt = os.path.join(opt.checkpoint_model_path,"generator_%d.pth" % (opt.resume_epoch-1)) + # saved_generator_chkpt = os.path.join(opt.checkpoint_model_path,"generator_best.pth") + generator.load_state_dict(torch.load(saved_generator_chkpt)) + saved_discriminator_chkpt = os.path.join(opt.checkpoint_model_path,"discriminator_%d.pth" % (opt.resume_epoch-1)) + # saved_discriminator_chkpt = os.path.join(opt.checkpoint_model_path,"discriminator_best.pth") + discriminator.load_state_dict(torch.load(saved_discriminator_chkpt)) + logger.info("Loaded Checkpoint model from epoch %d"%(opt.resume_epoch-1)) + + # Optimizers + optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) + optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) + + Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor + + train_path = os.path.join(opt.dataset_path, "train") + val_path = os.path.join(opt.dataset_path, "val") + + ## Need to use PairedImageDataset Dataset class + train_dataloader = DataLoader( + PairedImageDataset(train_path, opt, hr_shape=hr_shape), + batch_size=opt.batch_size, + shuffle=True, + num_workers=opt.n_cpu, + ) + + val_dataloader = DataLoader( + PairedImageDataset(val_path, opt, hr_shape=hr_shape), + batch_size=opt.save_size, + num_workers=opt.n_cpu, + ) + + # learning rate modification steps + milestones = [opt.decay_epoch, opt.decay_epoch + opt.lr_gap, opt.decay_epoch + opt.lr_gap*2, opt.decay_epoch + opt.lr_gap*3] + + total_train_batches = len(train_dataloader) + # snapshot_interval = 
round(total_train_batches/2) + snapshot_interval = 30 + + # ---------- + # Training + # ---------- + for epoch in range(opt.resume_epoch, opt.n_epochs): + + epoch_start_time = time.time() + + # Adjust LR + if epoch in milestones: + optimizer_G.param_groups[0]['lr'] *= 0.5 + optimizer_D.param_groups[0]['lr'] *= 0.5 + + for i, imgs in enumerate(train_dataloader): + + batches_done = epoch * total_train_batches + i + 1 + + # this will add channel axis: (Batch Size, Height, Width) --> (Batch Size, 1, Height, Width) + sparse_temp = torch.unsqueeze(imgs["sparse"], 1) + gt_temp = torch.unsqueeze(imgs["gt"], 1) + rgb_temp = imgs["rgb"] + + # Configure model input + sparse_depth = Variable(sparse_temp.type(Tensor)) + gt_depth = Variable(gt_temp.type(Tensor)) + imgs_rgb = Variable(rgb_temp.type(Tensor)) + + #send equal batch partitions to differnt gpus + sparse_depth, gt_depth, imgs_rgb = sparse_depth.to('cuda'), gt_depth.to('cuda'), imgs_rgb.to('cuda') + + # Adversarial ground truths + valid = Variable(Tensor(np.ones((imgs_rgb.size(0), *discriminator.module.output_shape))), requires_grad=False) + fake = Variable(Tensor(np.zeros((imgs_rgb.size(0), *discriminator.module.output_shape))), requires_grad=False) + + valid, fake = valid.to('cuda'), fake.to('cuda') + + # ----------------- + # Train Generator + # ----------------- + + optimizer_G.zero_grad() + + # Construct a depth map using RGB and lidar data + gen_depth = generator(imgs_rgb, sparse_depth) + + if "nyu" in opt.model: + gen_depth = gen_depth[:,:,6:-6,:] + gt_depth = gt_depth[:,:,6:-6,:] + + # Measure pixel-wise loss against ground truth + loss_pixel = criterion_pixel(gen_depth, gt_depth) + writer.add_scalar("Pixel_Loss/Train", loss_pixel, batches_done) + + # log learning rate + gen_lr = optimizer_G.param_groups[0]['lr'] + writer.add_scalar("Generateor_LR", gen_lr, batches_done) + + if batches_done < opt.warmup_batches: + # Warm-up (pixel-wise loss only) + loss_pixel.backward() + optimizer_G.step() + logger.info( + 
"[Epoch %d/%d] [Batch %d/%d] [G pixel: %f]" + % (epoch, opt.n_epochs-1, i+1, len(train_dataloader), loss_pixel.item()) + ) + continue + + # Extract validity predictions from discriminator + pred_real = discriminator(gt_depth).detach() + pred_fake = discriminator(gen_depth) + + # Adversarial loss (relativistic average GAN) + loss_GAN = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), valid) + writer.add_scalar("GAN_Loss/Train", loss_GAN, batches_done) + + gen_features = imgrad_yx(gen_depth) + real_features = imgrad_yx(gt_depth).detach() + loss_content = criterion_content(gen_features, real_features) + writer.add_scalar("Content_Loss/Train", loss_content, batches_done) + + # Total generator loss + loss_G = loss_content + opt.lambda_adv * loss_GAN + opt.lambda_pixel * loss_pixel + writer.add_scalar("Generator_Loss/Train", loss_G, batches_done) + # loss_G = opt.lambda_adv * loss_GAN + opt.lambda_pixel * loss_pixel + + loss_G.backward() + optimizer_G.step() + + # --------------------- + # Train Discriminator + # --------------------- + + optimizer_D.zero_grad() + + pred_real = discriminator(gt_depth) + pred_fake = discriminator(gen_depth.detach()) + + # Adversarial loss for real and fake images (relativistic average GAN) + loss_real = criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), valid) + writer.add_scalar("Discriminator_RealLoss/Train", loss_real, batches_done) + loss_fake = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), fake) + writer.add_scalar("Discriminator_FakeLoss/Train", loss_fake, batches_done) + + # Total loss + loss_D = (loss_real + loss_fake) / 2 + writer.add_scalar("Discriminator_Loss/Train", loss_D, batches_done) + + #Discriminator LR + disc_lr = optimizer_D.param_groups[0]['lr'] + writer.add_scalar("Discriminator_LR", disc_lr, batches_done) + + loss_D.backward() + optimizer_D.step() + + # -------------- + # Log Progress + # -------------- + + logger.info( + "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, content: 
%f, adv: %f, pixel: %f, lr: %f]" #removed content loss + % ( + epoch, + opt.n_epochs-1, + i+1, + len(train_dataloader), + loss_D.item(), + loss_G.item(), + loss_content.item(), # No content loss + loss_GAN.item(), + loss_pixel.item(), + gen_lr, + ) + ) + + if batches_done % snapshot_interval == 0: + # Save model checkpoints + generator_chkpt = os.path.join(saved_model_path,"generator_%d.pth" % epoch) + torch.save(generator.state_dict(), generator_chkpt) + discriminator_chkpt = os.path.join(saved_model_path,"discriminator_%d.pth" % epoch) + torch.save(discriminator.state_dict(), discriminator_chkpt) + logger.info("Saved Checkpoint at batch {}...".format(batches_done)) + + + if batches_done % snapshot_interval == 0: + with torch.no_grad(): + avg_rmse, avg_rel = validate(generator, discriminator, opt, Tensor, val_dataloader, criterion_GAN, criterion_content, criterion_pixel, logger, val_image_save_path, writer, batches_done) + + # save best checkpoint + if avg_rmse None: + + #convering to uint16 -> Grayscale + _,h,w,c = image_array.shape + image_array = image_array.reshape(-1,w,c) + image_array = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR) + # print('image_array',image_array.shape) + cv2.imwrite(fp, image_array) + +def save_sample_images(gt_depth, imgs_rgb, sparse_depth, gen_depth, image_save_path, image_id) -> None: + + denorm_gt = denormalize_dense(gt_depth) + denorm_sparse = denormalize_sparse(sparse_depth) + denorm_pred = denormalize_dense(gen_depth) + + gt_depth = generate_depth_cmap(denorm_gt) + sparse_depth = generate_depth_cmap(denorm_sparse) + gen_depth = generate_depth_cmap(denorm_pred) + + imgs_rgb = denormalize_rgb(imgs_rgb).permute(0,2,3,1).to('cpu').detach().numpy() + + img_grid = np.concatenate((gt_depth, imgs_rgb, sparse_depth, gen_depth), axis=2) + saved_image_file = os.path.join(image_save_path,"%04d.png"%image_id) + save_my_image(img_grid, saved_image_file) \ No newline at end of file diff --git a/validate.py b/validate.py new file mode 100644 
# Patch header residue kept from the collapsed diff:
# index 0000000..2d5acc4 --- /dev/null +++ b/validate.py
"""
Standalone validation script for the robust-multimodal-fusion-gan
In order to invoke type:

python validate.py --model nyu_modelA --gpus=0 --batch_size=16 --checkpoint_model_path=./logdir/nyu_train/saved_models/ -n nyu_test

1. The checkpoint model path has to have 2 files named generator_best.pth and discriminator_best.pth
2. -n --> give a name to the run
3. Modify the val dataloader path with appropriate data directory
4. Typically the directory has the following structure
    ----|->data.ShapeNetDepth|
            |->train|
                |->sparse_depth
                |->depth_gt
                |->image_rgb
                |->meta_info.txt
            |->val|
                |->sparse_depth
                |->depth_gt
                |->image_rgb
                |->meta_info.txt
            |->sample|
                |->sparse_depth
                |->depth_gt
                |->image_rgb
                |->meta_info.txt

5. The "depth_gt" and "lidar" are the folders containing dense and sparse depth respectively
6. The meta_info.txt contains the file names of these folders. Refer to misc/ folder for sample meta_info file
7. The folder "sample" contains a few sparse samples. This is to track the model learning visually.
"""
import argparse
import os
import numpy as np
import math
import itertools
import sys
import time

import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torch.autograd import Variable

from models.generator_models import *
from models.models import *
from datasets import *
from utils import *

import torch.nn as nn
import torch.nn.functional as F
import torch

from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter

LOGDIR = "./logdir/"


def getOpt():
    """Parse the command-line options of the standalone validation run.

    Returns:
        argparse.Namespace: the parsed options. ``--model`` and
        ``--checkpoint_model_path`` are mandatory; ``--name`` may be left
        unset, in which case ``main`` derives a run name from the model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="nyu_v2", help="name of the dataset (shapeNet or nyu_v2)")
    parser.add_argument("--model", type=str, default="nyu_modelA", required=True, help="name of the model (nyu_modelA | nyu_modelB)")
    parser.add_argument("--dataset_path", type=str, default="/home/mdl/mzk591/dataset/data.nyuv2/disk3/", help="path to the dataset")
    parser.add_argument("--batch_size", type=int, default=4, help="size of the batches")
    parser.add_argument("--save_size", type=int, default=8, help="batch size for saved outputs")
    parser.add_argument("--n_cpu", type=int, default=16, help="number of cpu threads to use during batch generation")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--validation_interval", type=int, default=4000, help="interval between two consecutive validations")
    parser.add_argument("--lambda_adv", type=float, default=5e-3, help="adversarial loss weight")
    parser.add_argument("--lambda_pixel", type=float, default=1e-2, help="pixel-wise loss weight")
    parser.add_argument("--gpus", metavar='DEV_ID', default=None,
                        help='Comma-separated list of GPU device IDs to be used (default is to use all available devices)')
    parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name')
    parser.add_argument('--meta_info_file', '-m', metavar='DIR', default="meta_info.txt", help='Meta file name')
    parser.add_argument("--checkpoint_model_path", type=str, required=True, help="Path to checkpoint model")

    return parser.parse_args()


def validate(generator, discriminator, opt, Tensor, val_dataloader, criterion_GAN, criterion_content, criterion_pixel, logger, val_image_save_path, writer, batches_done=0):
    """Run the generator over the validation loader and report depth metrics.

    For every batch the generator output is compared with the ground truth
    through ``get_loss`` (imported from ``utils``), which yields per-batch
    RMSE, REL and MAE values on the de-normalized depth. Up to five randomly
    chosen batches are also rendered to image files.

    Args:
        generator: model called as ``generator(imgs_rgb, sparse_depth)``.
        discriminator: wrapped model; ``discriminator.module.output_shape``
            is read, so it is expected to be wrapped (e.g. ``nn.DataParallel``).
        opt: parsed options; ``opt.model``, ``opt.lambda_adv`` and
            ``opt.lambda_pixel`` are read here.
        Tensor: tensor type used to cast the inputs.
        val_dataloader: loader yielding dicts with "sparse", "gt", "rgb".
        criterion_GAN, criterion_content, criterion_pixel: loss modules.
        logger: logger receiving per-batch and summary lines.
        val_image_save_path (str): directory under which a sub-folder named
            after ``batches_done`` is created for the sample images.
        writer: TensorBoard ``SummaryWriter``.
        batches_done (int): global step used for TensorBoard and the folder.

    Returns:
        tuple: ``(avg_rmse, avg_rel)`` where ``avg_rmse`` is the root of the
        mean of the squared per-batch RMSE values and ``avg_rel`` is the mean
        of the per-batch REL values.

    Raises:
        ValueError: if ``val_dataloader`` yields no batches.
    """
    total_val_batches = len(val_dataloader)
    if total_val_batches == 0:
        # np.random.randint(0, ...) would fail with an opaque "high <= 0".
        raise ValueError("validation dataloader is empty")

    # Indices of the batches whose outputs are rendered to disk (duplicates
    # collapse, so fewer than five batches may be saved).
    batch_to_be_saved = set(np.random.randint(total_val_batches, size=5).tolist())
    # batch_to_be_saved = [1, 2, 3, 4] #it can be any numbers

    val_sample_path = os.path.join(val_image_save_path, "%06d" % batches_done)
    os.makedirs(val_sample_path, exist_ok=True)

    loss_dict = {'rmse': [], 'rel': [], 'mae': []}

    for i, imgs in enumerate(val_dataloader):

        # this will add channel axis: (Batch Size, Height, Width) --> (Batch Size, 1, Height, Width)
        sparse_temp = torch.unsqueeze(imgs["sparse"], 1)
        gt_temp = torch.unsqueeze(imgs["gt"], 1)
        rgb_temp = imgs["rgb"]

        # Configure model input
        sparse_depth = Variable(sparse_temp.type(Tensor))
        gt_depth = Variable(gt_temp.type(Tensor))
        imgs_rgb = Variable(rgb_temp.type(Tensor))

        # send equal batch partitions to different gpus
        sparse_depth, gt_depth, imgs_rgb = sparse_depth.to('cuda'), gt_depth.to('cuda'), imgs_rgb.to('cuda')

        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((imgs_rgb.size(0), *discriminator.module.output_shape))), requires_grad=False)

        gen_depth = generator(imgs_rgb, sparse_depth)

        if "nyu" in opt.model:
            # Drop 6 rows at the top and bottom of every tensor.
            # NOTE(review): presumably undoes padding from 228 to 240 rows - confirm against datasets.py.
            gen_depth = gen_depth[:, :, 6:-6, :]
            gt_depth = gt_depth[:, :, 6:-6, :]
            sparse_depth = sparse_depth[:, :, 6:-6, :]
            imgs_rgb = imgs_rgb[:, :, 6:-6, :]

        '''calculation of content, pixel and GAN loss is optional'''

        # Extract validity predictions from discriminator
        pred_real = discriminator(gt_depth).detach()
        pred_fake = discriminator(gen_depth)

        # Adversarial loss (relativistic average GAN)
        loss_GAN = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), valid)

        gen_features = imgrad_yx(gen_depth)
        real_features = imgrad_yx(gt_depth).detach()
        loss_content = criterion_content(gen_features, real_features)

        # Measure pixel-wise loss against ground truth
        loss_pixel = criterion_pixel(gen_depth, gt_depth)

        # Total generator loss (kept for parity with training; not reported)
        loss_G = loss_content + opt.lambda_adv * loss_GAN + opt.lambda_pixel * loss_pixel

        '''calculation of content, pixel and GAN loss is optional'''

        pred, gt = gen_depth.detach().clone(), gt_depth.detach().clone()

        pred = denormalize_dense(pred)
        gt = denormalize_dense(gt)

        # new loss measures
        loss_rmse, loss_rel, loss_mae = get_loss(pred, gt)

        loss_dict['rmse'].append(loss_rmse.item())
        loss_dict['rel'].append(loss_rel.item())
        loss_dict['mae'].append(loss_mae.item())

        logger.info(
            "Validating [Batch %d/%d] [content: %.3f, pixel: %.3f, RMSE: %.3f, REL: %.3f, MAE: %.3f]"
            % (
                i + 1,
                len(val_dataloader),
                loss_content.item(),
                loss_pixel.item(),
                loss_rmse.item(),
                loss_rel.item(),
                loss_mae.item(),
            )
        )

        if i in batch_to_be_saved:

            save_sample_images(gt_depth, imgs_rgb, sparse_depth, gen_depth, val_sample_path, i)
            logger.info("Saved Validation Images...")

    # Per-batch RMSE values are combined through their squares, the other
    # two metrics through a plain mean.
    avg_rmse = np.sqrt(np.mean(np.square(loss_dict['rmse'])))
    avg_rel = np.mean(loss_dict['rel'])
    avg_mae = np.mean(loss_dict['mae'])

    writer.add_scalar("Final_RMSE_mean", avg_rmse, batches_done)
    # Fixed: this tag used to be fed avg_rmse, so the REL curve was wrong.
    writer.add_scalar("Final_REL_mean", avg_rel, batches_done)
    writer.add_scalar("Final_MAE_mean", avg_mae, batches_done)

    logger.info(
        "Final Avg loss after %d batches [RMSE: %.3f, REL: %.3f, MAE: %.3f]"
        % (
            batches_done,
            avg_rmse,
            avg_rel,
            avg_mae,
        )
    )

    return avg_rmse, avg_rel


def main():
    """Load the best generator/discriminator checkpoints and validate them.

    Creates ``LOGDIR/<name>/`` (log file, TensorBoard events, sample images),
    builds the model named by ``--model``, loads ``generator_best.pth`` and
    ``discriminator_best.pth`` from ``--checkpoint_model_path`` and runs
    ``validate`` once over the "val" split under ``--dataset_path``.
    Exits with status 1 on invalid ``--gpus`` or ``--model`` values.
    """
    opt = getOpt()

    if opt.name is None:
        # os.path.join(LOGDIR, None, ...) would raise TypeError; derive a
        # run name from the model so the script still works without -n.
        opt.name = "%s_val" % opt.model

    # create the logdir if it does not exist
    os.makedirs(LOGDIR, exist_ok=True)

    val_image_save_path = os.path.join(LOGDIR, opt.name, "val_images")
    log_file_name = os.path.join(LOGDIR, opt.name, '%s.log' % opt.name)
    tensorboard_save_path = os.path.join(LOGDIR, opt.name)

    os.makedirs(val_image_save_path, exist_ok=True)

    # Create a logger
    logger = createLogger(log_file_name)

    # print(opt)
    logger.info(opt)

    # initiate tensorboard logger
    writer = SummaryWriter(log_dir=tensorboard_save_path)

    if opt.gpus is not None:
        try:
            opt.gpus = [int(s) for s in opt.gpus.split(',')]
        except ValueError:
            logger.error('ERROR: Argument --gpus must be a comma-separated list of integers only')
            sys.exit(1)
        available_gpus = torch.cuda.device_count()
        for dev_id in opt.gpus:
            # Negative IDs are rejected too; they are not valid device indices.
            if not 0 <= dev_id < available_gpus:
                logger.error('ERROR: GPU device ID {0} requested, but only {1} devices available'
                             .format(dev_id, available_gpus))
                sys.exit(1)
        # Set default device in case the first one on the list != 0
        torch.cuda.set_device(opt.gpus[0])

    if 'shapeNet' in opt.model:
        hr_shape = (192, 256)
    elif "nyu" in opt.model:
        hr_shape = (240, 304)
    else:
        # Without this branch hr_shape would be unbound further down.
        logger.error('ERROR: Cannot infer the image size, --model must contain "shapeNet" or "nyu"')
        sys.exit(1)

    model_config = {
        "img_size": hr_shape,
        "rgb_chans": 3,
        "lidar_chans": 1,
        "patch_size": 16,
        "embed_dim": 768,
        "depth": 12,
        "n_heads": 12,
        "qkv_bias": True,
        "mlp_ratio": 4,
    }

    # Initialize generator and discriminator.
    # The model is looked up by name among the star-imported definitions
    # instead of eval()-ing command-line text; errors raised while building
    # the model now propagate instead of being reported as a bad name.
    model_factory = globals().get(opt.model)
    if not callable(model_factory):
        print("Please select model from: nyu_modelA | nyu_modelB")
        sys.exit(1)
    generator = model_factory(**model_config)

    generator = nn.DataParallel(generator, device_ids=opt.gpus)
    generator.cuda()

    discriminator = Discriminator(input_shape=(opt.channels, *hr_shape))
    discriminator = nn.DataParallel(discriminator, device_ids=opt.gpus)
    discriminator.cuda()

    # Losses
    criterion_GAN = torch.nn.BCEWithLogitsLoss().cuda()
    criterion_content = NormalLoss().cuda()
    criterion_pixel = torch.nn.L1Loss().cuda()

    # Load state dict for generator and discriminator (loaded after the
    # DataParallel wrap, so the checkpoints must carry wrapped key names)
    saved_generator_chkpt = os.path.join(opt.checkpoint_model_path, "generator_best.pth")
    generator.load_state_dict(torch.load(saved_generator_chkpt))
    saved_discriminator_chkpt = os.path.join(opt.checkpoint_model_path, "discriminator_best.pth")
    discriminator.load_state_dict(torch.load(saved_discriminator_chkpt))

    # Only evaluate
    generator.eval()
    discriminator.eval()

    Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor

    val_path = os.path.join(opt.dataset_path, "val")

    ## Need to use PairedImageDataset Dataset class
    # NOTE(review): the loader is batched by --save_size; --batch_size is
    # parsed but not read anywhere in this file.
    val_dataloader = DataLoader(
        PairedImageDataset(val_path, opt, hr_shape=hr_shape),
        batch_size=opt.save_size,
        num_workers=opt.n_cpu,
    )

    # final validation
    with torch.no_grad():
        avg_rmse, avg_rel = validate(generator, discriminator, opt, Tensor, val_dataloader, criterion_GAN, criterion_content, criterion_pixel, logger, val_image_save_path, writer)

    writer.flush()
    writer.close()

    logger.info("Validation Done. Check results!")


if __name__ == '__main__':
    main()