
CountCLIP : [Re] Teaching Clip to Count to Ten

This repository contains a reimplementation of the paper Teaching CLIP to Count to Ten by Google Research, published at ICCV 2023. The paper presents a method for fine-tuning vision-language models (VLMs) such as CLIP to improve zero-shot counting accuracy while maintaining zero-shot classification performance. It does so by adding a counting-contrastive loss term to the original loss function, which changes the training objective so that the model learns to discriminate between captions that state the correct object count for an image and captions that state an incorrect one.
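As a rough illustration of the idea, the counting-contrastive term can be sketched as a two-way softmax over an image's similarity to its correct caption and to a counterfactual caption with a wrong count. This is a minimal, dependency-free sketch and not the code in model.ipynb: the function names, the use of cosine similarity without a temperature, and the weight `lam` are assumptions made for illustration.

```python
import math


def _dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def _normalize(v):
    norm = math.sqrt(_dot(v, v))
    return [x / norm for x in v]


def _softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def counting_loss(image_emb, caption_emb, counterfactual_emb):
    """Binary contrastive loss for one (image, caption, counterfactual) triple.

    Computes -log(exp(s_pos) / (exp(s_pos) + exp(s_neg))), where s_pos and
    s_neg are cosine similarities between the image and the correct and
    counterfactual captions. Assumed form, for illustration only.
    """
    img = _normalize(image_emb)
    s_pos = _dot(img, _normalize(caption_emb))
    s_neg = _dot(img, _normalize(counterfactual_emb))
    # -log(e^p / (e^p + e^n)) simplifies to log(1 + e^(n - p)).
    return _softplus(s_neg - s_pos)


def total_loss(clip_loss, count_losses, lam=1.0):
    """Adds the mean counting loss, weighted by a hypothetical `lam`, to the CLIP loss."""
    if not count_losses:
        return clip_loss
    return clip_loss + lam * sum(count_losses) / len(count_losses)


if __name__ == "__main__":
    image = [1.0, 0.0]
    good = [0.9, 0.1]   # embedding of the caption with the correct count
    bad = [0.1, 0.9]    # embedding of the counterfactual caption
    print(counting_loss(image, good, bad))
    print(total_loss(2.0, [counting_loss(image, good, bad)]))
```

The loss is small when the image is closer to the correct caption than to the counterfactual one, and it equals log 2 when the two captions are equally similar to the image.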

Demo of our model learning to count

Usage

Colab Demo: Open In Colab

To run the Python script (Python 3.10 recommended), run the commands below. The dataset files must be downloaded into the scripts folder before experiment.py is run:

 git clone https://github.com/SforAiDl/CountCLIP.git
 cd CountCLIP/scripts  
 conda create -n <env_name> python=3.10  
 conda activate <env_name>  
 pip install -r requirements.txt  
 python3 experiment.py  

Repository structure

  • count_set_gen.ipynb contains the implementation for generating the counting set as described in Section 3.1 of the paper.
  • model.ipynb contains the implementation for the counting loss function as described in Section 3.2 of the paper.
  • The folder data_utils contains miscellaneous notebooks for downloading data, merging datasets, and similar tasks.
  • The folder old contains incomplete and outdated code used to make the final implementation.
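The text side of building a counting set can be sketched as follows: find captions that mention a spelled-out count, then create a counterfactual caption by swapping that count for a different one. This is only an illustrative sketch, not the code in count_set_gen.ipynb. The range of number words ("two" to "ten"), the helper name `make_counterfactual`, and the choice to swap only the first number word are assumptions, and the image-side check that the stated count matches what is actually visible is not shown.

```python
import random
import re

# Assumed vocabulary of spelled-out counts.
NUMBER_WORDS = ["two", "three", "four", "five", "six",
                "seven", "eight", "nine", "ten"]
_NUMBER_RE = re.compile(r"\b(" + "|".join(NUMBER_WORDS) + r")\b", re.IGNORECASE)


def make_counterfactual(caption, rng=random):
    """Return the caption with its first number word replaced by a different one.

    Returns None when the caption mentions no number word, which means it is
    not a candidate for the counting set.
    """
    match = _NUMBER_RE.search(caption)
    if match is None:
        return None
    original = match.group(1).lower()
    # Pick any other count so the new caption is guaranteed to be wrong.
    replacement = rng.choice([w for w in NUMBER_WORDS if w != original])
    return caption[:match.start()] + replacement + caption[match.end():]


if __name__ == "__main__":
    rng = random.Random(0)  # seeded only to make the demo repeatable
    for text in ["Three dogs playing on a beach", "A cat on a sofa"]:
        print(text, "->", make_counterfactual(text, rng))
```

Each (caption, counterfactual caption) pair produced this way is what a counting loss needs as its positive and negative text inputs.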

Dataset

We created a small counting set of ~2,000 images by scanning over 2 million of the 400 million images in the original dataset. It is merged with ~13,000 non-counting images from the same dataset. The entire merged dataset, along with the required JSON/CSV files, can be found here: DOI.

  • data.zip - merged counting and non-counting data, along with the validation data (the CountBench dataset).
  • merged.json - JSON for the merged (counting + non-counting) data.
  • val.json - JSON for the CountBench data.
  • faulty.csv - CSV for removing faulty non-counting images.

Special Thanks