This project implements a Residual Network (ResNet) model using PyTorch for image classification on the CIFAR-10 dataset. The ResNet architecture is enhanced with various normalization techniques to compare their effects on model performance.
The project consists of the following components. The data_loader function loads and preprocesses the CIFAR-10 dataset, applying transformations such as resizing, conversion to tensors, and normalization. It can return either the training and validation splits or the test set.
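A minimal sketch of what such a loader might look like; the exact signature, validation split size, and normalization statistics are assumptions, not taken verbatim from the project code:

```python
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

def data_loader(data_dir="./data", batch_size=128, test=False):
    # Resize the 32x32 CIFAR-10 images and normalize with commonly used statistics
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2470, 0.2435, 0.2616)),
    ])

    if test:
        test_set = datasets.CIFAR10(data_dir, train=False, download=True,
                                    transform=transform)
        return DataLoader(test_set, batch_size=batch_size, shuffle=False)

    full_train = datasets.CIFAR10(data_dir, train=True, download=True,
                                  transform=transform)
    # Hold out 10% of the training images for validation (assumed split)
    val_size = len(full_train) // 10
    train_set, val_set = random_split(
        full_train, [len(full_train) - val_size, val_size])
    return (DataLoader(train_set, batch_size=batch_size, shuffle=True),
            DataLoader(val_set, batch_size=batch_size, shuffle=False))
```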
A separate helper function selects among different normalization layers: Batch Normalization, Group Normalization, Instance Normalization, and Layer Normalization. The selected normalization technique is applied to the convolutional layers of the ResNet.
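One way such a selector might be written; the name get_norm_layer and its arguments are illustrative assumptions:

```python
import torch.nn as nn

def get_norm_layer(norm_type, num_channels, num_groups=8):
    # Map a string name to the corresponding 2D normalization module.
    if norm_type == "batch":
        return nn.BatchNorm2d(num_channels)
    if norm_type == "group":
        return nn.GroupNorm(num_groups, num_channels)
    if norm_type == "instance":
        return nn.InstanceNorm2d(num_channels)
    if norm_type == "layer":
        # GroupNorm with a single group normalizes over all channels,
        # the usual way to apply LayerNorm to convolutional feature maps.
        return nn.GroupNorm(1, num_channels)
    raise ValueError(f"Unknown normalization type: {norm_type}")
```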
The ResidualBlock class defines a single residual block used within the ResNet architecture. Each block contains two convolutional layers with the chosen normalization technique and a residual connection. The optional downsample module adjusts the shape of the input so it can be added to the block's output when the spatial size or channel count changes.
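A sketch of such a block, reusing the hypothetical get_norm_layer helper from above; the exact layer configuration is an assumption:

```python
import torch.nn as nn

class ResidualBlock(nn.Module):
    # Two 3x3 convolutions with the chosen normalization and a skip connection.
    def __init__(self, in_channels, out_channels, stride=1,
                 downsample=None, norm_type="batch"):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.norm1 = get_norm_layer(norm_type, out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.norm2 = get_norm_layer(norm_type, out_channels)
        self.downsample = downsample  # projects the input when shapes differ

    def forward(self, x):
        identity = x
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        return self.relu(out + identity)
```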
The ResNet class assembles the full architecture from stacked residual blocks: an initial convolutional stem with max pooling, followed by several stages of residual blocks with increasing channel counts. Final classification is performed by an average-pooling layer followed by a fully connected layer.
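A ResNet-18-style sketch of how the stages might be assembled; the block counts, channel widths, and helper names are assumptions:

```python
import torch.nn as nn

class ResNet(nn.Module):
    def __init__(self, layers=(2, 2, 2, 2), num_classes=10, norm_type="batch"):
        super().__init__()
        self.in_channels = 64
        # Initial convolution, normalization, and max pooling
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            get_norm_layer(norm_type, 64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.layer1 = self._make_layer(64, layers[0], 1, norm_type)
        self.layer2 = self._make_layer(128, layers[1], 2, norm_type)
        self.layer3 = self._make_layer(256, layers[2], 2, norm_type)
        self.layer4 = self._make_layer(512, layers[3], 2, norm_type)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, out_channels, blocks, stride, norm_type):
        # Project the identity path when the stride or channel count changes
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, 1, stride, bias=False),
                get_norm_layer(norm_type, out_channels),
            )
        layers = [ResidualBlock(self.in_channels, out_channels, stride,
                                downsample, norm_type)]
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(ResidualBlock(out_channels, out_channels,
                                        norm_type=norm_type))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer4(self.layer3(self.layer2(self.layer1(x))))
        x = self.avgpool(x).flatten(1)
        return self.fc(x)
```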
The script includes a training loop that iterates over different normalization techniques. For each technique, the model is trained on the training dataset and evaluated on the validation dataset. The training progress, loss, and validation accuracy are displayed.
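A condensed sketch of that loop under assumed hyperparameters (SGD with momentum, a fixed epoch count); the function name train_and_evaluate is illustrative:

```python
import torch
import torch.nn as nn

def train_and_evaluate(norm_type, train_loader, val_loader, num_epochs=10, lr=0.01):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResNet(norm_type=norm_type).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                momentum=0.9, weight_decay=5e-4)

    for epoch in range(num_epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

        # Evaluate on the validation set after each epoch
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                preds = model(images).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        print(f"[{norm_type}] epoch {epoch + 1}: last-batch loss {loss.item():.4f}, "
              f"val acc {100 * correct / total:.2f}%")
    return 100 * correct / total

# Repeat training for each normalization technique being compared
train_loader, val_loader = data_loader()
results = {norm: train_and_evaluate(norm, train_loader, val_loader)
           for norm in ["batch", "group", "instance", "layer"]}
```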
After training with different normalization techniques, the project generates a bar plot to compare the accuracies achieved by each technique on the validation dataset.
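Assuming a results dictionary like the one built in the previous sketch, the comparison plot could be produced with matplotlib:

```python
import matplotlib.pyplot as plt

# results maps each normalization technique to its validation accuracy (%)
names, accs = zip(*results.items())
plt.bar(names, accs)
plt.xlabel("Normalization technique")
plt.ylabel("Validation accuracy (%)")
plt.title("ResNet on CIFAR-10: effect of normalization layer")
plt.show()
```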
Together, the per-epoch logs and the final bar plot demonstrate how the choice of normalization technique affects the ResNet model's validation accuracy after a fixed number of training epochs.