[arXiv](https://arxiv.org/abs/2310.07855)
This repo contains the PyTorch implementation of our ICLR 2024 paper (spotlight):
CrIBo: Self-Supervised Learning via Cross-Image Object-Level Bootstrapping
Tim Lebailly*, Thomas Stegmüller*, Behzad Bozorgtabar, Jean-Philippe Thiran and Tinne Tuytelaars.
Our code has only a few dependencies. First, install PyTorch for your machine following https://pytorch.org/get-started/locally/. Then install the remaining dependency:

```shell
pip install einops
```
Run the `main_cribo.py` file. Command-line arguments are defined in `parser.py`:

```shell
python main_cribo.py
```

Make sure to use the right arguments specified in the table below!
The code is compatible with SLURM. For running on a single node with 8 GPUs:

```shell
#!/bin/bash
#SBATCH --job-name=cribo
#SBATCH --account=<slurm_account>
#SBATCH --cpus-per-task=7
#SBATCH --gpus-per-node=8
#SBATCH --mem=<mem>
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --open-mode=append
#SBATCH --time=4320

master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr
export MASTER_PORT=12802

srun --unbuffered \
python main_cribo.py \
    --imagenet1k_path /path/to/ilsvrc2012 \
    --output_dir . \
    --n_tokens 32 \
    --queue_size 25000 \
    --pos_alpha 1.0 1.0 \
    --arch vit_small
```
Alternatively, you can use `torchrun` or `torch.distributed.launch` (deprecated):

```shell
python -m torch.distributed.launch --nproc_per_node=8 main_cribo.py \
    --imagenet1k_path /path/to/ilsvrc2012 \
    --output_dir . \
    --n_tokens 32 \
    --queue_size 25000 \
    --pos_alpha 1.0 1.0 \
    --arch vit_small
```
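Since `torch.distributed.launch` is deprecated, a `torchrun` invocation with the same script arguments could look like the sketch below. This assumes `main_cribo.py` reads the rank from the environment variables `torchrun` sets (as DINO-derived code typically does); verify against `parser.py` before relying on it.

```shell
torchrun --nproc_per_node=8 main_cribo.py \
    --imagenet1k_path /path/to/ilsvrc2012 \
    --output_dir . \
    --n_tokens 32 \
    --queue_size 25000 \
    --pos_alpha 1.0 1.0 \
    --arch vit_small
```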
If you find our work useful, please consider citing:

```bibtex
@misc{lebailly2023cribo,
    title={CrIBo: Self-Supervised Learning via Cross-Image Object-Level Bootstrapping},
    author={Tim Lebailly and Thomas Stegmüller and Behzad Bozorgtabar and Jean-Philippe Thiran and Tinne Tuytelaars},
    year={2023},
    eprint={2310.07855},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```
You can download the full checkpoints, which contain backbone and projection head weights for both the student and teacher networks. We also provide the detailed arguments needed to reproduce our results.
| pretraining dataset | arch | params | batch size | Dense NN retrieval ADE20k (mIoU) | Dense NN retrieval PVOC12 (mIoU) | download | |
| --- | --- | --- | --- | --- | --- | --- | --- |
| COCO | ViT-S/16 | 21M | 256 | 23.4 | 58.1 | ckpt | args |
| ImageNet-1k | ViT-S/16 | 21M | 1024 | 28.3 | 73.2 | ckpt | args |
| ImageNet-1k | ViT-B/16 | 85M | 1024 | 30.0 | 74.9 | ckpt | args |
```python
import torch

checkpoint_to_load = 'vitb16-in.pth'  # choose the checkpoint to load
loaded = torch.load(checkpoint_to_load, map_location='cpu')
print(loaded.keys())  # inspect the top-level entries of the checkpoint
```
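Once you know the key layout, you typically want only the backbone weights for downstream use. Below is a minimal sketch of stripping a prefix from a state dict; the `backbone.` prefix is an assumption for illustration, not confirmed by this repo, so check the output of `print(loaded.keys())` first.

```python
import torch
from collections import OrderedDict

def extract_backbone(state_dict, prefix="backbone."):
    """Keep only the entries under `prefix`, with the prefix stripped."""
    return OrderedDict(
        (k[len(prefix):], v) for k, v in state_dict.items() if k.startswith(prefix)
    )

# Demo with a mock checkpoint so the sketch runs without the real file.
# The key names here are hypothetical.
mock = {"backbone.cls_token": torch.zeros(1), "head.weight": torch.zeros(2)}
backbone = extract_backbone(mock)
print(list(backbone.keys()))  # ['cls_token']
```

The resulting dict can then be passed to `model.load_state_dict(backbone, strict=False)` on a matching ViT architecture.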
This code is adapted from CrOC, which is based on the codebase of DINO.