This repo provides a full implementation of the Vision Transformer (ViT), following a series of articles about vision-language models. The project is implemented from scratch in PyTorch. Keep in mind that ViT does not differ much from the original Transformer; it only makes a few modifications.
The main architectural modifications in the Vision Transformer are:
- Linear Projection: a convolution layer is used here, not to extract features but to split the 256x256 image into patches. This lets the Transformer, which is mostly used for Seq2Seq modeling in NLP, process an image as a sequence. The linear projection turns each patch into a token vector.
- MLP (Multi-Layer Perceptron): an MLP is used for the classification task because MLPs are widely used for classification. The only addition is the CLS token.
Note: the Vision Transformer uses only the encoder blocks rather than the whole Transformer model. For more information, read the article.
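The patch-splitting idea described above can be illustrated with a minimal, framework-free sketch. It is not the repo's code: the helper names `split_into_patches` and `sequence_length` are hypothetical, and the patch size of 16 is an assumption, since only the 256x256 image size is stated here. In PyTorch the same effect is commonly obtained with a convolution whose kernel size and stride both equal the patch size.

```python
def split_into_patches(image, patch_size):
    """Split a 2D image (a list of rows) into flattened, non-overlapping patches."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    patches = []
    # Walk over the top-left corner of every patch, row by row.
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = [image[top + r][left + c]
                     for r in range(patch_size)
                     for c in range(patch_size)]
            patches.append(patch)
    return patches


def sequence_length(image_size, patch_size, use_cls_token=True):
    """Number of tokens the encoder receives for a square image."""
    num_patches = (image_size // patch_size) ** 2
    # The CLS token adds one extra position in front of the patch tokens.
    return num_patches + (1 if use_cls_token else 0)


# Toy 4x4 "image" split into 2x2 patches: 4 patches of 4 values each.
toy_image = [[row * 4 + col for col in range(4)] for row in range(4)]
toy_patches = split_into_patches(toy_image, 2)

# Assumed patch size of 16 on a 256x256 image: 256 patches plus one CLS token.
tokens_for_256 = sequence_length(256, 16)
```

Each flattened patch is what the linear projection would then map to a token vector of the model's embedding size.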
First, create an environment to run the project in its directory.
Packages:
- numpy
- torchmetrics
- matplotlib
- torch
- torchvision
- pytorch-lightning
- opencv-python
- Create the environment by running:
conda create --name Segemnetation python=3.6
- Make sure requirements.txt exists in the repo, then install the packages by running:
pip install -r requirements.txt
The Transformer model has many hyper-parameters to tune, depending on the experiment and the data size. To make tuning easier, a script called CONFIG.py contains all the parameters. Set them up for your purpose and the script will automatically generate a YAML file, config.yml.
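As a rough sketch of what such a script might do, the example below writes a flat dictionary of hyper-parameters to a YAML file using only the standard library. The parameter names and values are hypothetical and are not taken from CONFIG.py; the helpers `to_yaml` and `write_config` are also assumptions. A real script would more likely rely on PyYAML's `yaml.safe_dump`.

```python
import os
import tempfile

# Hypothetical hyper-parameters; the real CONFIG.py may use other names and values.
HYPOTHETICAL_CONFIG = {
    "image_size": 256,
    "patch_size": 16,
    "embed_dim": 768,
    "num_layers": 12,
    "num_heads": 12,
    "learning_rate": 0.0003,
    "batch_size": 32,
}


def to_yaml(params):
    """Serialize a flat dict of numbers as simple 'key: value' YAML lines."""
    return "".join(f"{key}: {value}\n" for key, value in params.items())


def write_config(params, path):
    """Write the parameters to `path` and return the path."""
    with open(path, "w", encoding="utf-8") as handle:
        handle.write(to_yaml(params))
    return path


# Written to a temporary directory here so the sketch does not touch the project folder.
config_path = write_config(
    HYPOTHETICAL_CONFIG, os.path.join(tempfile.mkdtemp(), "config.yml")
)
```

This hand-written serializer only handles flat numeric values; nested sections or strings needing quotes are the reason to prefer a proper YAML library.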
Note: this project uses the PyTorch Lightning framework because it makes it easy to write the training loop and to use multiple GPUs.
To train the model, run the following command. When training finishes, a checkpoint of the model called ViT.ckpt is saved automatically in the current project folder.
python train.py --Config config.yml --device "gpu"
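For readers wondering how the two flags above could be read inside a script, here is a minimal `argparse` sketch. It is an assumption about how train.py might parse its arguments, not a copy of it: the default of `cpu`, the allowed choices, and the `build_parser` helper are all hypothetical.

```python
import argparse


def build_parser():
    """Build a parser for the two flags used in the training command."""
    parser = argparse.ArgumentParser(description="Train the ViT model (sketch)")
    parser.add_argument("--Config", required=True,
                        help="path to the YAML config file, e.g. config.yml")
    # The choices and default below are assumptions made for this sketch.
    parser.add_argument("--device", default="cpu", choices=["cpu", "gpu"],
                        help="device to train on")
    return parser


# Parse the same arguments as in the command above.
args = build_parser().parse_args(["--Config", "config.yml", "--device", "gpu"])
```

Because the flag is spelled `--Config`, argparse exposes the value as `args.Config` with the same capitalization.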
After training is done, run predict.py to check the predictions using the saved checkpoint. In the command below:
INSERT_Val_DATA: validation or test data that has already been processed
INSERT_CHECKPOINT_MODEL: the saved model checkpoint
INSERT_OUTPUT_STR.PNG: the output PNG file
python predict.py --Path INSERT_Val_DATA --Path_Checkpoint INSERT_CHECKPOINT_MODEL --OUTPUT INSERT_OUTPUT_STR.PNG