A PyTorch implementation of "Synthesize Then Align: Modality Alignment Augmentation for Zero-shot Image Captioning with Synthetic Data".
Install the requirements with:
$ pip install -r requirements.txt
Download the image caption datasets from the web and arrange the data directory as follows:
data
├── mscoco
│   ├── train_captions.json    # captions of training split
│   ├── refs.json              # references of test split
│   └── test
│       └── test_image         # images of test split
├── flickr
│   ├── train_captions.json    # captions of training split
│   ├── refs.json              # references of test split
│   └── test
│       └── test_image         # images of test split
├── ss1m
│   └── train_captions.json    # captions of training split
├── nocaps
│   ├── refs.json              # references of test split
│   └── test
│       └── test_image         # images of test split
└── object_phrases.json        # object phrases used to construct hard prompts
Download pretrained GPT-2 from Hugging Face into the ./model
directory (a download sketch follows the listing below). Stable Diffusion v1.5, Llama-3-8B and CLIP ViT-B/32 can be downloaded automatically within the code. The directory looks like:
model
├── gpt2
│   └── ...   # pretrained GPT-2
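One way to place GPT-2 under ./model/gpt2 is via the Hugging Face transformers API. This is only a sketch; the exact checkpoint variant is whatever the training code expects.

from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Download the standard "gpt2" checkpoint and store it locally under ./model/gpt2.
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model.save_pretrained("./model/gpt2")
tokenizer.save_pretrained("./model/gpt2")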
We perform data preparation before training by running the code under the ./data_preparation
directory.
- Extract CLIP features for the training and test splits
python data_prepare.py --dataset {mscoco|flickr30k|ss1m}
Executing the above command produces the following files (using MSCOCO as an example):
data
├── mscoco
│   ├── train
│   │   ├── text_feature_dict.pt               # CLIP text features of captions
│   │   ├── synthetic_image_pk_dict.json       # map of synthetic pseudo pairs
│   │   └── ...
│   └── test
│       ├── test_image_global_feature_dict.pt  # CLIP image features of test images
│       └── ...
└── object_phrases_feature_dict.pt             # CLIP text features of object phrases
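For reference, the CLIP text-feature extraction performed by data_prepare.py might look roughly like the sketch below. The caption-file layout, the use of the Hugging Face CLIP wrapper, and the feature normalization are assumptions, not the repository's exact code.

import json
import torch
from transformers import CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Assumed format: train_captions.json is a flat list of caption strings.
captions = json.load(open("data/mscoco/train_captions.json"))

text_feature_dict = {}
with torch.no_grad():
    for cap in captions:
        inputs = processor(text=[cap], return_tensors="pt", padding=True, truncation=True).to(device)
        feat = model.get_text_features(**inputs)                       # shape (1, 512) for ViT-B/32
        text_feature_dict[cap] = (feat / feat.norm(dim=-1, keepdim=True)).cpu()

torch.save(text_feature_dict, "data/mscoco/train/text_feature_dict.pt")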
- Image synthesis. This step generates synthetic images for all conditional texts in the given corpus.
python image_synthesis.py --dataset {mscoco|flickr30k|ss1m}
Executing the above command produces the following directories and files (using MSCOCO as an example):
data
├── mscoco
│   ├── train
│   │   ├── synthetic_images                          # synthetic images of captions
│   │   ├── synthetic_image_global_feature_dict.pt    # global image features of synthetic images
│   │   └── ...
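The core of this step, generating one image per caption with Stable Diffusion v1.5, might look roughly like the following. The hub id, file naming, and sampling settings are assumptions rather than the repository's exact choices.

import json
import os
import torch
from diffusers import StableDiffusionPipeline

# Stable Diffusion v1.5; the exact hub id may differ depending on the available mirror.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

captions = json.load(open("data/mscoco/train_captions.json"))  # assumed: list of caption strings
out_dir = "data/mscoco/train/synthetic_images"
os.makedirs(out_dir, exist_ok=True)

for i, cap in enumerate(captions):
    image = pipe(cap, num_inference_steps=50, guidance_scale=7.5).images[0]
    image.save(os.path.join(out_dir, f"{i:06d}.jpg"))  # assumed naming scheme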
- Caption rephrasing. This step rephrases the training captions with the LLM.
python caption_rephrasing.py --dataset {mscoco|flickr30k|ss1m}
Executing the above command produces the following file (using MSCOCO as an example):
data
├── mscoco
│   ├── train
│   │   ├── llm_rephrasing_cap.json    # texts from rephrasing
│   │   └── ...
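A rough sketch of rephrasing captions with Llama-3-8B-Instruct through the transformers pipeline is shown below; the prompt wording and the output format of llm_rephrasing_cap.json are assumptions about what caption_rephrasing.py does.

import json
import torch
from transformers import pipeline

generator = pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

captions = json.load(open("data/mscoco/train_captions.json"))  # assumed: list of caption strings
rephrased = {}
for cap in captions:
    messages = [{"role": "user", "content": f"Rephrase this image caption in one sentence: {cap}"}]
    out = generator(messages, max_new_tokens=64, do_sample=False)
    rephrased[cap] = out[0]["generated_text"][-1]["content"]    # assistant reply

with open("data/mscoco/train/llm_rephrasing_cap.json", "w") as f:
    json.dump(rephrased, f, indent=2)  # assumed format: original caption -> rephrased caption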
- Supporting image synthesis. This step generates supporting images for the texts produced by rephrasing.
python supporting_image_synthesis.py --dataset {mscoco|flickr30k|ss1m}
Executing the above command produces the following directories and files (using MSCOCO as an example):
data
├── mscoco
│   ├── train
│   │   ├── supporting_images                          # supporting images of the rephrased texts
│   │   ├── supporting_image_global_feature_dict.pt    # global image features of supporting images
│   │   ├── supporting_image_pk_dict.json              # map of supporting image names to conditional texts
│   │   └── ...
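Generation itself mirrors the previous step; the additional piece is extracting CLIP ViT-B/32 global image features for the generated images. The sketch below is illustrative; the directory layout and normalization are assumptions rather than the repository's exact code.

import os
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

img_dir = "data/mscoco/train/supporting_images"
feature_dict = {}
with torch.no_grad():
    for name in sorted(os.listdir(img_dir)):
        image = Image.open(os.path.join(img_dir, name)).convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(device)
        feat = model.get_image_features(**inputs)                      # shape (1, 512)
        feature_dict[name] = (feat / feat.norm(dim=-1, keepdim=True)).cpu()

torch.save(feature_dict, "data/mscoco/train/supporting_image_global_feature_dict.pt")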
- Constructing hard prompts. This step retrieves, by cosine similarity of CLIP features, the Top-N support features most similar to the target feature; these are used to retrieve relevant object phrases.
python hard_prompt_retrieval.py --dataset {mscoco|flickr30k|ss1m}
Executing the above command produces the following files (using MSCOCO as an example):
data
├── mscoco
│   ├── train
│   │   ├── synthetic_image_hard_prompt_dict.json     # hard prompts of synthetic images
│   │   ├── supporting_image_hard_prompt_dict.json    # hard prompts of supporting images
│   │   └── ...
│   ├── test
│   │   ├── test_image_hard_prompt_dict.json          # hard prompts of test images
│   │   └── ...
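The Top-N retrieval reduces to a normalized dot product over the cached CLIP features. The sketch below is illustrative; the dictionary layouts, the value of N, and the JSON dump are assumptions about hard_prompt_retrieval.py.

import json
import torch

# Cached CLIP features from the earlier steps (assumed layout: {key: tensor of shape (1, 512)}).
phrase_feats = torch.load("data/object_phrases_feature_dict.pt")
image_feats = torch.load("data/mscoco/train/synthetic_image_global_feature_dict.pt")

phrases = list(phrase_feats.keys())
phrase_matrix = torch.cat([phrase_feats[p] for p in phrases], dim=0)
phrase_matrix = phrase_matrix / phrase_matrix.norm(dim=-1, keepdim=True)

top_n = 5  # assumed value of N
hard_prompt_dict = {}
for name, feat in image_feats.items():
    feat = feat / feat.norm(dim=-1, keepdim=True)
    sims = (feat @ phrase_matrix.T).squeeze(0)              # cosine similarity to every object phrase
    idx = sims.topk(top_n).indices.tolist()
    hard_prompt_dict[name] = [phrases[i] for i in idx]

with open("data/mscoco/train/synthetic_image_hard_prompt_dict.json", "w") as f:
    json.dump(hard_prompt_dict, f, indent=2)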
We use a re-pairing mechanism to construct training pairs at each iteration and augment cross-modal alignment modeling with the soft prompt and the hard prompt.
Running the following command creates a folder ./trained_model/{dataset}
in the root directory and saves the training log, arguments, and model weights.
python training.py --dataset {mscoco|flickr30k|ss1m}
Executing the above command produces the following files (using MSCOCO as an example):
trained_model
├── mscoco
│   ├── captioner_{epoch}.pt    # trained model
│   ├── train_log.txt           # training log
│   └── train_args.json         # training arguments
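The re-pairing rule itself is defined in training.py. Purely as an illustration of the idea, one plausible form of re-pairing is nearest-neighbor matching between caption features and synthetic/supporting image features at each iteration; the function below is a hypothetical sketch, not the paper's actual mechanism.

import torch

def re_pair(text_feats: torch.Tensor, image_feats: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-pairing: match each caption to its most similar image feature.

    text_feats: (T, 512) CLIP text features; image_feats: (I, 512) CLIP image features.
    Returns a (T,) tensor of image indices, one per caption.
    """
    t = text_feats / text_feats.norm(dim=-1, keepdim=True)
    v = image_feats / image_feats.norm(dim=-1, keepdim=True)
    return (t @ v.T).argmax(dim=-1)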
Executing the following command performs an in-domain or cross-domain zero-shot experiment, selected by setting --test_dataset.
python inference.py
--model_path {trained model} #trained model
--test_dataset {mscoco|flickr30k|nocaps} #target domain
The inference results file is written to the same directory as the model weights (using MSCOCO as an example):
trained_model
├── mscoco
│   ├── captioner_{epoch}.pt                        # trained model
│   ├── captioner_{epoch}_{test_dataset}_res.json   # inference result
│   └── ...
Provide the inference result file and the corresponding refs.json
to compute the evaluation metrics.
python eval_metrics.py
--candidates_json {inference result} #inference result
--references_json {reference file} #reference file
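If you want to sanity-check the numbers outside the provided script, standard captioning metrics can be computed with the pycocoevalcap package. The sketch below assumes the result file maps image ids to a single caption and refs.json maps image ids to lists of reference captions; both layouts are assumptions.

import json
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.cider.cider import Cider

# Replace {epoch} and the dataset name with those of your checkpoint.
candidates = json.load(open("trained_model/mscoco/captioner_{epoch}_mscoco_res.json"))
references = json.load(open("data/mscoco/refs.json"))

tokenizer = PTBTokenizer()
gts = tokenizer.tokenize({k: [{"caption": c} for c in v] for k, v in references.items()})
res = tokenizer.tokenize({k: [{"caption": candidates[k]}] for k in candidates})

bleu_scores, _ = Bleu(4).compute_score(gts, res)
cider_score, _ = Cider().compute_score(gts, res)
print("BLEU-4:", bleu_scores[3], "CIDEr:", cider_score)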