Skip to content

Commit

Permalink
Merge pull request #5 from Cpt-Shaan/main
Browse files Browse the repository at this point in the history
fsrcnn-merge-request
  • Loading branch information
ChinmayK0607 authored Dec 24, 2024
2 parents 1feb297 + 3110178 commit 7d79008
Show file tree
Hide file tree
Showing 20 changed files with 3,284 additions and 0 deletions.
1,556 changes: 1,556 additions & 0 deletions Summer 2024/FSRCNN/FSRCNN Implementation/FSRCNN.ipynb

Large diffs are not rendered by default.

109 changes: 109 additions & 0 deletions Summer 2024/FSRCNN/FSRCNN Implementation/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import math
import torchvision
import torchvision.transforms.v2 as transforms
import torchvision.transforms.functional as TF
import random

import numpy as np
import cv2
from PIL import Image
import os
import sys
import matplotlib.pyplot as plt
from collections import namedtuple
from torchvision import models

# device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# dataset class

class SRdatasets(Dataset):
def __init__(self, dataset_path = 'C:/Users/athar/MLprojects/dataloader_task/Datasets', transform = None, scale_factor = 2):
script_directory = os.path.dirname(os.path.abspath(sys.argv[0]))
os.chdir(dataset_path)
input_list = []
target_list = []

for dset_name in os.listdir():
#{
os.chdir(dataset_path + '\\' + dset_name)
for dir_name in os.listdir():
#{
dir_num = int(dir_name[-1])

# changing directory to export images
cwd = os.getcwd()
os.chdir(cwd + '\\' + dir_name)

num_images = len(os.listdir())
num_images = num_images//2
for i in range(1, num_images):
input_arr, target = self.extract("LR", dir_num, i), self.extract("HR", dir_num, i)

# adding horizontal flip deterministic transform
hflip_input = TF.hflip(torch.from_numpy(input_arr))
hflip_target = TF.hflip(torch.from_numpy(target))

input_list.extend([input_arr, hflip_input])
target_list.extend([target, hflip_target])

# returning back to the dset directory
os.chdir(dataset_path + '\\' + dset_name)
#}
#}

# returning back to the current directory
# this way the flow of rest of the program is not affected
os.chdir(script_directory)

# converting the list of np.arrays into higher dimenstion np.array since list -> tensor conversion is much slower
input_arr = np.array(input_list)
target_arr = np.array(target_list)

self.input_data = torch.Tensor(input_arr) # shape: (num_channels = 3, height, width)
self.target_data = torch.Tensor(target_arr)
self.size = len(self.input_data)
self.transform = transform




def extract(self, res = "LR", dir_num = 2, i = 1, scale_factor = 2):
leading_zeros = 3 - len(str(i))
number_str = leading_zeros * "0" + str(i)

final_str = "img_" + number_str + "_SRF_" + str(dir_num) + "_" + res + ".png"
img = Image.open(final_str)
npimg = np.asarray(img) # npimg.shape(480, 320, 3) or (320, 480, 3)


# resizing all images into same dimensions
if res == "LR":
scale = scale_factor
else:
scale = 1
npimg = cv2.resize(npimg, dsize = (320//scale, 480//scale), interpolation = cv2.INTER_CUBIC)
npimg = np.array(npimg)
if len(npimg.shape) == 2:
npimg = cv2.cvtColor(npimg, cv2.COLOR_GRAY2BGR)

return np.transpose(npimg, axes = (2, 0, 1)) # returning in form (num_channels = 3, rows, col)

def __getitem__(self, index):
if isinstance(index, int):
input_data, target_data = self.input_data[index], self.target_data[index]
return input_data, target_data

if isinstance(index, slice):
return torch.stack([(self.input_data[i]) for i in range(*index.indices(len(self)))]), torch.stack([(self.target_data[i]) for i in range(*index.indices(len(self)))])
# return type: tuple of the form: (tensor of inputs, tensor of targets)
# shape of each tensor: (size_of_slice, num_channels = 3, height, width)

def __len__(self):
return self.size
95 changes: 95 additions & 0 deletions Summer 2024/FSRCNN/FSRCNN Implementation/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import torch
from torch import nn
# aliter: import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import math
import torchvision
import torchvision.transforms.v2 as transforms
import torchvision.transforms.functional as TF
import random

import numpy as np
import cv2
from PIL import Image
import os
import sys
import matplotlib.pyplot as plt
from collections import namedtuple
from torchvision import models

# device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# defining model
class FSRCNN(nn.Module):
def __init__(self, scale_factor = 2, num_channels=1, d=56, s=12, m=4):
super(FSRCNN, self).__init__()
self.first_part = nn.Sequential(
nn.Conv2d(num_channels, d, kernel_size=5, padding=2),
nn.PReLU(d)
)

self.mid_part = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)]
for _ in range(m):
self.mid_part.extend([nn.Conv2d(s, s, kernel_size=3, padding=1), nn.PReLU(s)])
self.mid_part.extend([nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)])
self.mid_part = nn.Sequential(*self.mid_part)
# In Python, the * operator is used for unpacking an iterable (like a list or tuple) into individual elements.

self.last_part = nn.ConvTranspose2d(d, num_channels, kernel_size=9, stride=scale_factor, padding=4,
output_padding=scale_factor-1)

self._initialize_weights()

def _initialize_weights(self):
for m in self.first_part:
if isinstance(m, nn.Conv2d):
nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
# this formula is inspired from He initialisation used in Fully connected ANNs
nn.init.zeros_(m.bias.data)
for m in self.mid_part:
if isinstance(m, nn.Conv2d):
nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
nn.init.zeros_(m.bias.data)
nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)
nn.init.zeros_(self.last_part.bias.data)

def forward(self, x):
x = x/255.0
x = self.first_part(x)

# to ease in debugging:
if torch.isnan(x).any():
print("NaN detected in first part")
x = self.mid_part(x)
if torch.isnan(x).any():
print("NaN detected in mid part")
x = self.last_part(x)
if torch.isnan(x).any():
print("NaN detected in last part")

'''
As an alternative to min-max normalisation, A modified sigmoid was thougth about, but due to the
problem of vanishing gradient it was discarded:
# n = 3 might give good results (to tackle the vanishing gradient problem as posed by sigmoid function)
n = 1
x = torch.sigmoid(x/n)
x = x*255.0
'''

# min max normalise ( done in a cascading fassion )
min_values, _ = torch.min(x, dim = -1, keepdim = True)
min_values, _ = torch.min(min_values, dim = -2, keepdim = True)
# expected shape: (batch_size, num_channels, 1, 1)
max_values, _ = torch.max(x, dim = -1, keepdim = True)
max_values, _ = torch.max(max_values, dim = -2, keepdim = True)

# broadcasting expected along dimensions -1 and -2
x = (x-min_values)/(max_values-min_values)
x = x * 255.0


return x
49 changes: 49 additions & 0 deletions Summer 2024/FSRCNN/FSRCNN Implementation/transfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import math
import torchvision
import torchvision.transforms.v2 as transforms
import torchvision.transforms.functional as TF
import random

import numpy as np
import cv2
from PIL import Image
import os
import sys
import matplotlib.pyplot as plt
from collections import namedtuple
from torchvision import models

# device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# transfer model
def transfer_model(PATH, dataset, model, optimizer, epoch_loss_list, epoch_acc_list, train_size, test_size, prev_epochs = 0):
# transfering previous checkpoint
try:
#{
checkpoint = torch.load(PATH, weights_only = True)
print('checkpoint loaded successfully')
transfer = int(input("transfer previous model? 1/0: "))
if(transfer == 1):
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch_loss_list.extend(checkpoint['epoch_loss_list'])
epoch_acc_list.extend(checkpoint['epoch_acc_list'])
prev_epochs = checkpoint['epoch']
print('model transfered successfully')
#}
except Exception as e:
print('Exception occured, running without loading checkpoint')
checkpnt_flag = 0
print(e)


train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size])

return train_set, test_set, prev_epochs # returning prev_epochs is important as in python there is no straightforward provision to modify integers inside functions (passing by reference)
Loading

0 comments on commit 7d79008

Please sign in to comment.