forked from IvLabs/Summer-Projects
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from Cpt-Shaan/main
fsrcnn-merge-request
- Loading branch information
Showing
20 changed files
with
3,284 additions
and
0 deletions.
There are no files selected for viewing
1,556 changes: 1,556 additions & 0 deletions
1,556
Summer 2024/FSRCNN/FSRCNN Implementation/FSRCNN.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import torch | ||
from torch import nn | ||
import torch.nn.functional as F | ||
import torchvision | ||
from torch.utils.data import Dataset, DataLoader | ||
import math | ||
import torchvision | ||
import torchvision.transforms.v2 as transforms | ||
import torchvision.transforms.functional as TF | ||
import random | ||
|
||
import numpy as np | ||
import cv2 | ||
from PIL import Image | ||
import os | ||
import sys | ||
import matplotlib.pyplot as plt | ||
from collections import namedtuple | ||
from torchvision import models | ||
|
||
# device configuration | ||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
|
||
# dataset class | ||
|
||
class SRdatasets(Dataset): | ||
def __init__(self, dataset_path = 'C:/Users/athar/MLprojects/dataloader_task/Datasets', transform = None, scale_factor = 2): | ||
script_directory = os.path.dirname(os.path.abspath(sys.argv[0])) | ||
os.chdir(dataset_path) | ||
input_list = [] | ||
target_list = [] | ||
|
||
for dset_name in os.listdir(): | ||
#{ | ||
os.chdir(dataset_path + '\\' + dset_name) | ||
for dir_name in os.listdir(): | ||
#{ | ||
dir_num = int(dir_name[-1]) | ||
|
||
# changing directory to export images | ||
cwd = os.getcwd() | ||
os.chdir(cwd + '\\' + dir_name) | ||
|
||
num_images = len(os.listdir()) | ||
num_images = num_images//2 | ||
for i in range(1, num_images): | ||
input_arr, target = self.extract("LR", dir_num, i), self.extract("HR", dir_num, i) | ||
|
||
# adding horizontal flip deterministic transform | ||
hflip_input = TF.hflip(torch.from_numpy(input_arr)) | ||
hflip_target = TF.hflip(torch.from_numpy(target)) | ||
|
||
input_list.extend([input_arr, hflip_input]) | ||
target_list.extend([target, hflip_target]) | ||
|
||
# returning back to the dset directory | ||
os.chdir(dataset_path + '\\' + dset_name) | ||
#} | ||
#} | ||
|
||
# returning back to the current directory | ||
# this way the flow of rest of the program is not affected | ||
os.chdir(script_directory) | ||
|
||
# converting the list of np.arrays into higher dimenstion np.array since list -> tensor conversion is much slower | ||
input_arr = np.array(input_list) | ||
target_arr = np.array(target_list) | ||
|
||
self.input_data = torch.Tensor(input_arr) # shape: (num_channels = 3, height, width) | ||
self.target_data = torch.Tensor(target_arr) | ||
self.size = len(self.input_data) | ||
self.transform = transform | ||
|
||
|
||
|
||
|
||
def extract(self, res = "LR", dir_num = 2, i = 1, scale_factor = 2): | ||
leading_zeros = 3 - len(str(i)) | ||
number_str = leading_zeros * "0" + str(i) | ||
|
||
final_str = "img_" + number_str + "_SRF_" + str(dir_num) + "_" + res + ".png" | ||
img = Image.open(final_str) | ||
npimg = np.asarray(img) # npimg.shape(480, 320, 3) or (320, 480, 3) | ||
|
||
|
||
# resizing all images into same dimensions | ||
if res == "LR": | ||
scale = scale_factor | ||
else: | ||
scale = 1 | ||
npimg = cv2.resize(npimg, dsize = (320//scale, 480//scale), interpolation = cv2.INTER_CUBIC) | ||
npimg = np.array(npimg) | ||
if len(npimg.shape) == 2: | ||
npimg = cv2.cvtColor(npimg, cv2.COLOR_GRAY2BGR) | ||
|
||
return np.transpose(npimg, axes = (2, 0, 1)) # returning in form (num_channels = 3, rows, col) | ||
|
||
def __getitem__(self, index): | ||
if isinstance(index, int): | ||
input_data, target_data = self.input_data[index], self.target_data[index] | ||
return input_data, target_data | ||
|
||
if isinstance(index, slice): | ||
return torch.stack([(self.input_data[i]) for i in range(*index.indices(len(self)))]), torch.stack([(self.target_data[i]) for i in range(*index.indices(len(self)))]) | ||
# return type: tuple of the form: (tensor of inputs, tensor of targets) | ||
# shape of each tensor: (size_of_slice, num_channels = 3, height, width) | ||
|
||
def __len__(self): | ||
return self.size |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import torch | ||
from torch import nn | ||
# aliter: import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torchvision | ||
from torch.utils.data import Dataset, DataLoader | ||
import math | ||
import torchvision | ||
import torchvision.transforms.v2 as transforms | ||
import torchvision.transforms.functional as TF | ||
import random | ||
|
||
import numpy as np | ||
import cv2 | ||
from PIL import Image | ||
import os | ||
import sys | ||
import matplotlib.pyplot as plt | ||
from collections import namedtuple | ||
from torchvision import models | ||
|
||
# device configuration | ||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
|
||
# defining model | ||
class FSRCNN(nn.Module): | ||
def __init__(self, scale_factor = 2, num_channels=1, d=56, s=12, m=4): | ||
super(FSRCNN, self).__init__() | ||
self.first_part = nn.Sequential( | ||
nn.Conv2d(num_channels, d, kernel_size=5, padding=2), | ||
nn.PReLU(d) | ||
) | ||
|
||
self.mid_part = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)] | ||
for _ in range(m): | ||
self.mid_part.extend([nn.Conv2d(s, s, kernel_size=3, padding=1), nn.PReLU(s)]) | ||
self.mid_part.extend([nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)]) | ||
self.mid_part = nn.Sequential(*self.mid_part) | ||
# In Python, the * operator is used for unpacking an iterable (like a list or tuple) into individual elements. | ||
|
||
self.last_part = nn.ConvTranspose2d(d, num_channels, kernel_size=9, stride=scale_factor, padding=4, | ||
output_padding=scale_factor-1) | ||
|
||
self._initialize_weights() | ||
|
||
def _initialize_weights(self): | ||
for m in self.first_part: | ||
if isinstance(m, nn.Conv2d): | ||
nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel()))) | ||
# this formula is inspired from He initialisation used in Fully connected ANNs | ||
nn.init.zeros_(m.bias.data) | ||
for m in self.mid_part: | ||
if isinstance(m, nn.Conv2d): | ||
nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel()))) | ||
nn.init.zeros_(m.bias.data) | ||
nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001) | ||
nn.init.zeros_(self.last_part.bias.data) | ||
|
||
def forward(self, x): | ||
x = x/255.0 | ||
x = self.first_part(x) | ||
|
||
# to ease in debugging: | ||
if torch.isnan(x).any(): | ||
print("NaN detected in first part") | ||
x = self.mid_part(x) | ||
if torch.isnan(x).any(): | ||
print("NaN detected in mid part") | ||
x = self.last_part(x) | ||
if torch.isnan(x).any(): | ||
print("NaN detected in last part") | ||
|
||
''' | ||
As an alternative to min-max normalisation, A modified sigmoid was thougth about, but due to the | ||
problem of vanishing gradient it was discarded: | ||
# n = 3 might give good results (to tackle the vanishing gradient problem as posed by sigmoid function) | ||
n = 1 | ||
x = torch.sigmoid(x/n) | ||
x = x*255.0 | ||
''' | ||
|
||
# min max normalise ( done in a cascading fassion ) | ||
min_values, _ = torch.min(x, dim = -1, keepdim = True) | ||
min_values, _ = torch.min(min_values, dim = -2, keepdim = True) | ||
# expected shape: (batch_size, num_channels, 1, 1) | ||
max_values, _ = torch.max(x, dim = -1, keepdim = True) | ||
max_values, _ = torch.max(max_values, dim = -2, keepdim = True) | ||
|
||
# broadcasting expected along dimensions -1 and -2 | ||
x = (x-min_values)/(max_values-min_values) | ||
x = x * 255.0 | ||
|
||
|
||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
|
||
import torch | ||
from torch import nn | ||
import torch.nn.functional as F | ||
import torchvision | ||
from torch.utils.data import Dataset, DataLoader | ||
import math | ||
import torchvision | ||
import torchvision.transforms.v2 as transforms | ||
import torchvision.transforms.functional as TF | ||
import random | ||
|
||
import numpy as np | ||
import cv2 | ||
from PIL import Image | ||
import os | ||
import sys | ||
import matplotlib.pyplot as plt | ||
from collections import namedtuple | ||
from torchvision import models | ||
|
||
# device configuration | ||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
|
||
# transfer model | ||
def transfer_model(PATH, dataset, model, optimizer, epoch_loss_list, epoch_acc_list, train_size, test_size, prev_epochs = 0): | ||
# transfering previous checkpoint | ||
try: | ||
#{ | ||
checkpoint = torch.load(PATH, weights_only = True) | ||
print('checkpoint loaded successfully') | ||
transfer = int(input("transfer previous model? 1/0: ")) | ||
if(transfer == 1): | ||
model.load_state_dict(checkpoint['model_state_dict']) | ||
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) | ||
epoch_loss_list.extend(checkpoint['epoch_loss_list']) | ||
epoch_acc_list.extend(checkpoint['epoch_acc_list']) | ||
prev_epochs = checkpoint['epoch'] | ||
print('model transfered successfully') | ||
#} | ||
except Exception as e: | ||
print('Exception occured, running without loading checkpoint') | ||
checkpnt_flag = 0 | ||
print(e) | ||
|
||
|
||
train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size]) | ||
|
||
return train_set, test_set, prev_epochs # returning prev_epochs is important as in python there is no straightforward provision to modify integers inside functions (passing by reference) |
Oops, something went wrong.