This project is an image classifier that uses pre-trained models such as VGG16 and ResNet18 to assign images to classes. When running the script, the user can specify the model architecture, the learning rate, the number of hidden units, the number of training epochs, and whether to train on a GPU.
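The script requires the following packages:
- torch
- torchvision
- argparse (part of the Python standard library)
- matplotlib
- numpy
- PIL (installed as the Pillow package)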
The script can be run using the following command:
python train.py dataset_folder
- dataset_folder: directory containing the training data
- save_dir (optional): directory to save checkpoints
- arch (optional): model architecture, either "vgg16" or "resnet18"
- learning_rate (optional): learning rate for the optimizer
- hidden_units (optional): number of hidden units for the classifier
- epochs (optional): number of training epochs
- gpu (optional): flag to use GPU for training
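As an illustration of how the arguments above might be parsed, here is a minimal sketch using argparse. The `--` flag spellings and every default value are assumptions made for this example; they are not taken from train.py, so check the script itself for the actual names and defaults.

```python
import argparse


def build_parser():
    """Build a parser for the arguments listed above.

    Flag spellings and defaults are assumptions for illustration only.
    """
    parser = argparse.ArgumentParser(description="Train an image classifier")
    # Required positional argument: where the data lives.
    parser.add_argument("dataset_folder", help="directory containing the training data")
    # Optional arguments (hypothetical defaults).
    parser.add_argument("--save_dir", default=".", help="directory to save checkpoints")
    parser.add_argument("--arch", choices=["vgg16", "resnet18"], default="vgg16",
                        help="model architecture")
    parser.add_argument("--learning_rate", type=float, default=0.001,
                        help="learning rate for the optimizer")
    parser.add_argument("--hidden_units", type=int, default=512,
                        help="number of hidden units for the classifier")
    parser.add_argument("--epochs", type=int, default=5,
                        help="number of training epochs")
    # A flag: present means True, absent means False.
    parser.add_argument("--gpu", action="store_true", help="use GPU for training")
    return parser


# Parse an explicit argument list instead of sys.argv so the example is reproducible.
args = build_parser().parse_args(["dataset_folder", "--arch", "resnet18", "--gpu"])
print(args.arch, args.learning_rate, args.gpu)
```

Restricting `arch` with `choices` makes argparse reject unsupported architectures before any training starts.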
The script expects the dataset to be structured as follows:

    dataset_folder
    |__ train
    |   |__ class1
    |   |   |__ image1.jpg
    |   |   |__ image2.jpg
    |   |   |__ ...
    |   |__ class2
    |   |   |__ image1.jpg
    |   |   |__ image2.jpg
    |   |   |__ ...
    |   |__ ...
    |__ valid
    |   |__ class1
    |   |   |__ image1.jpg
    |   |   |__ image2.jpg
    |   |   |__ ...
    |   |__ class2
    |   |   |__ image1.jpg
    |   |   |__ image2.jpg
    |   |   |__ ...
    |   |__ ...
    |__ test
        |__ class1
        |   |__ image1.jpg
        |   |__ image2.jpg
        |   |__ ...
        |__ class2
        |   |__ image1.jpg
        |   |__ image2.jpg
        |   |__ ...
        |__ ...

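Before training, it can help to confirm that a folder follows this layout. The sketch below uses only the standard library to count the images per class in each split. The helper name `count_images` and the set of accepted file extensions are assumptions for this example and are not part of train.py.

```python
from pathlib import Path

# Assumed set of image extensions; adjust to match your data.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images(dataset_folder):
    """Return {split: {class_name: image_count}} for train, valid and test."""
    root = Path(dataset_folder)
    counts = {}
    for split in ("train", "valid", "test"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        # Each subdirectory of a split is treated as one class.
        counts[split] = {
            class_dir.name: sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
            )
            for class_dir in sorted(split_dir.iterdir())
            if class_dir.is_dir()
        }
    return counts
```

Calling `count_images("dataset_folder")` raises `FileNotFoundError` if one of the three split directories is missing, which surfaces layout problems early.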
The script prints the number of images in each dataset and the class names. After each epoch, it displays the loss and accuracy of the model. The trained model is saved to the specified save_dir under the name checkpoint.pth.
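To make the checkpoint step concrete, here is a minimal sketch of saving and restoring a model with `torch.save` and `torch.load`. The helper names and the keys stored in the checkpoint dictionary (`arch`, `hidden_units`, `class_to_idx`, `state_dict`) are assumptions for illustration; the layout train.py actually writes may differ.

```python
import os

import torch
from torch import nn


def save_checkpoint(model, save_dir, arch, hidden_units, class_to_idx):
    """Save model weights plus the metadata needed to rebuild the model.

    The dictionary keys are hypothetical, not taken from train.py.
    """
    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, "checkpoint.pth")
    torch.save(
        {
            "arch": arch,
            "hidden_units": hidden_units,
            "class_to_idx": class_to_idx,
            "state_dict": model.state_dict(),
        },
        path,
    )
    return path


def load_state(path, model):
    """Load the saved weights into an already constructed model."""
    # map_location="cpu" lets a checkpoint trained on GPU load on a CPU-only machine.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    return checkpoint
```

Storing the architecture name and hidden unit count next to the weights means the same network can be rebuilt later before `load_state` is called.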