
Commit

Merge pull request #8 from souradipp76/app
Updating paper and README
souradipp76 authored Oct 22, 2024
2 parents 3dcb997 + c05f30e commit ce109cd
Showing 17 changed files with 136 additions and 72 deletions.
13 changes: 7 additions & 6 deletions README.md
@@ -4,20 +4,18 @@
[![CI](https://github.com/souradipp76/MM-PoE/actions/workflows/main.yml/badge.svg)](https://github.com/souradipp76/MM-PoE/actions/workflows/main.yml)


**Multiple Choice Reasoning via Process of Elimination using Multi-Modal models**
**Multiple Choice Reasoning via Process of Elimination using Multi-Modal Models**


## What is MM-PoE?

Multi-Modal Process of Elimination (MM-PoE) is a method to enhance vision language models' performance on multiple-choice visual reasoning by employing a two-step scoring system that first eliminates incorrect options and then predicts from the remaining ones. Our experiments across three question answering datasets show the method's effectiveness, particularly in visual reasoning tasks.

**Statement of Need**

Large language models (LLMs) excel at in-context learning for multiple choice reasoning tasks but often treat all options equally, unlike humans, who typically eliminate incorrect choices before selecting the correct answer. The same is true for vision language models (VLMs) in the case of visual question answering tasks with multiple choices. This discrepancy can limit the effectiveness of vision language models in accurately solving such tasks. To address this, we introduce Multi-Modal Process of Elimination (MM-PoE), a two-step scoring method designed to enhance VLM performance by mimicking human reasoning strategies in multi-modal settings.

In the first step, the method evaluates and scores each option, systematically eliminating those that appear incorrect. The second step involves masking these eliminated options, allowing the VLM to focus solely on the remaining viable choices to make a final prediction. Our zero-shot experiments across three datasets demonstrate MM-PoE's effectiveness, particularly excelling in logical reasoning scenarios . Additionally, MM-PoE proves adaptable to few-shot settings and is compatible with the current state-of-the-art vision language models (VLMs).
In the first step, the method evaluates and scores each option, systematically eliminating those that appear incorrect. The second step involves masking these eliminated options, allowing the VLM to focus solely on the remaining viable choices to make a final prediction. Our zero-shot experiments across three datasets demonstrate MM-PoE's effectiveness, particularly excelling in logical reasoning scenarios. Additionally, MM-PoE proves adaptable to few-shot settings and is compatible with the current state-of-the-art vision language models (VLMs).

By implementing MM-PoE, researchers and practitioners can experiment and significantly improve the accuracy and reliability of VLMs in multiple choice reasoning tasks, making it a valuable tool for advancing machine learning models for visual reasoning.
Using this tool, researchers and practitioners can experiment and significantly improve the accuracy and reliability of VLMs in multiple choice reasoning tasks, making it a valuable tool for advancing machine learning models for visual reasoning.
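
As a rough illustration of the two-step procedure described above, the sketch below assumes a generic `score_options` callable standing in for a VLM likelihood scorer and a below-the-mean elimination rule (the image input is omitted for brevity); it is not code from this repository.

```python
# Minimal, illustrative sketch of Process-of-Elimination scoring (assumptions noted above).
from typing import Callable, List

def poe_predict(
    question: str,
    options: List[str],
    score_options: Callable[[str, List[str]], List[float]],
    mask_token: str = "[MASK]",
) -> int:
    # Step 1: score every option and eliminate those scoring below the mean.
    scores = score_options(question, options)
    threshold = sum(scores) / len(scores)
    keep = [s >= threshold for s in scores]

    # Step 2: mask the eliminated options and re-score; predict the best surviving option.
    masked = [opt if k else mask_token for opt, k in zip(options, keep)]
    final = score_options(question, masked)
    return max((i for i, k in enumerate(keep) if k), key=lambda i: final[i])

# Toy run with a dummy scorer, only to show the control flow end to end.
dummy_scorer = lambda q, opts: [float(len(o)) for o in opts]
print(poe_predict("What is in the image?", ["cat", "a large dog", "fish"], dummy_scorer))
```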

## Installing MM-PoE

@@ -65,7 +63,10 @@ $ python -m mm_poe
$ mm_poe
```

The application will prompt the user to provide relevant inputs for a multiple choice question e.g a question, multiple answer choices for the question and the path to the image relevant the question context. Once the inputs are provided, the predicted answer will be displayed based on the selections. Note that this application runs inference for only a single sample at a time.
The application will prompt the user to provide relevant inputs for a multiple choice question, e.g., a question, multiple answer choices for the question, and the path to the image relevant to the question context. Once the inputs are provided, the predicted answer will be displayed based on the prompt outputs. Note that this application runs inference for only a single sample at a time.


<img src="paper/figures/cli.png" alt="Example" width="500">

### Running Experiments

7 changes: 4 additions & 3 deletions mm_poe/cli.py
@@ -65,7 +65,7 @@ def main():
).ask()

args.loading_precision = questionary.select(
message="Select model checkpoint?",
message="Select model precision?",
choices=["FP32", "FP16", "BF16", "INT8", "INT4"],
default="FP32",
).ask()
@@ -116,7 +116,8 @@ def main():
"Image Path?", default="./images/image.png"
).ask()
args.label = questionary.select(
message="Answer:", choices=[str(x) for x in range(args.num_options)]
message="Ground Truth Option:",
choices=[str(x) for x in range(args.num_options)],
).ask()
args.label = int(args.label)
args.method = "process_of_elimination"
@@ -394,4 +395,4 @@ def main():
)
)
option = int(lm_predictions.numpy()[0])
logger.info(f"Answer: {option}")
logger.info(f"Predicted Option: {option}. Answer: {args.choices[option]}")
3 changes: 3 additions & 0 deletions mm_poe/data/data_downloaders.sh
@@ -80,6 +80,9 @@ cd Annotations
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/Annotations_Train_mscoco.zip
unzip Annotations_Train_mscoco.zip
rm Annotations_Train_mscoco.zip
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/Annotations_Val_mscoco.zip
unzip Annotations_Val_mscoco.zip
rm Annotations_Val_mscoco.zip
mkdir ../Questions
cd ../Questions
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/Questions_Train_mscoco.zip
17 changes: 8 additions & 9 deletions mm_poe/methods/utils/data.py
@@ -505,7 +505,6 @@ def preprocess_function_causal_vqa_channel(examples, **kwargs):
tokenizer = processor.tokenizer
image_processor = processor.image_processor

ending_names = [k for k in examples.keys() if k.startswith("hypothesis")]
num_choice = len(ending_names)
question_headers = examples[header_name]
first_sentences = [
@@ -1263,24 +1262,24 @@ def vqa_loader(path, args):
examples = []

print("Loading annotations and questions...")
train_anno = json.load(open(ann_file, "r"))
train_ques = json.load(open(question_file, "r"))
anno = json.load(open(ann_file, "r"))
ques = json.load(open(question_file, "r"))

if args.calibration_prompt is not None:
uncond_premise = args.calibration_prompt
else:
uncond_premise = " the answer is:"

for i in range(len(train_anno["annotations"])):
ans = train_anno["annotations"][i]["multiple_choice_answer"]
img_id = train_anno["annotations"][i]["image_id"]
for i in range(len(anno["annotations"])):
ans = anno["annotations"][i]["multiple_choice_answer"]
img_id = anno["annotations"][i]["image_id"]
# question_id = train_anno['annotations'][i]['question_id']
image_path = os.path.join(
img_dir, "COCO_train2014_" + "%012d.jpg" % img_id
img_dir, "COCO_%s2014_" % args.split + "%012d.jpg" % img_id
)

question = train_ques["questions"][i]["question"]
mc_ans = train_ques["questions"][i]["multiple_choices"]
question = ques["questions"][i]["question"]
mc_ans = ques["questions"][i]["multiple_choices"]
label = mc_ans.index(ans)

if getattr(args, "multiple_choice_prompt", None) is not None:
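
As a note on the `vqa_loader` change above: the COCO image filename is now built from the dataset split instead of being hard-coded to `train`. A small, standalone illustration of that naming convention follows (the example split and id are assumptions):

```python
# Illustrative only: assembling a COCO-2014 image filename from a split and an image id,
# mirroring the `"COCO_%s2014_" % args.split + "%012d.jpg" % img_id` expression in the diff above.
def coco_2014_filename(split: str, img_id: int) -> str:
    return "COCO_%s2014_" % split + "%012d.jpg" % img_id

print(coco_2014_filename("val", 42))  # -> COCO_val2014_000000000042.jpg
```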
4 changes: 2 additions & 2 deletions mm_poe/methods/utils/methods.py
@@ -169,8 +169,8 @@ def inference_language_modeling(
labels
)
pbar.set_description(
f"Language modeling accuracy: {lm_accuracy:.4f},\
Average language modeling accuracy: {avg_lm_accuracy:.4f}"
f"Language modeling accuracy: {lm_accuracy:.4f}, "
+ f"Average language modeling accuracy: {avg_lm_accuracy:.4f}"
)
avg_log_probs = torch.cat(avg_log_probs, dim=0)
return avg_log_probs, lm_accuracy, avg_lm_accuracy, lm_predictions
File renamed without changes.
File renamed without changes.
45 changes: 45 additions & 0 deletions mm_poe/results/language_modeling.csv
@@ -0,0 +1,45 @@
model_family,checkpoint,loading_precision,dataset,batch_size,method,seed,n_shot,sample,accuracy
GIT,microsoft/git-base-vqav2,FP32,ai2d,2,language_modeling,0,0,100,0.2500
GIT,microsoft/git-base-vqav2,FP32,ai2d,2,language_modeling,1,0,100,0.2400
GIT,microsoft/git-base-vqav2,FP32,ai2d,2,language_modeling,2,0,100,0.2700
GIT,microsoft/git-base-vqav2,FP32,ai2d,2,language_modeling,3,0,100,0.2300
GIT,microsoft/git-base-vqav2,FP32,ai2d,2,language_modeling,4,0,100,0.2800
GIT,microsoft/git-base-textvqa,FP32,ai2d,2,language_modeling,0,0,100,0.2900
GIT,microsoft/git-base-textvqa,FP32,scienceqa,2,language_modeling,0,0,100,0.2600
GIT,microsoft/git-base-textvqa,FP32,ai2d,2,language_modeling,0,0,100,0.2600
GIT,microsoft/git-base-textvqa,FP32,scienceqa,2,language_modeling,1,0,100,0.1300
GIT,microsoft/git-base-textvqa,FP32,ai2d,2,language_modeling,1,0,100,0.2300
GIT,microsoft/git-base-textvqa,FP32,scienceqa,2,language_modeling,2,0,100,0.2500
GIT,microsoft/git-base-textvqa,FP32,ai2d,2,language_modeling,2,0,100,0.2600
GIT,microsoft/git-base-textvqa,FP32,scienceqa,2,language_modeling,3,0,100,0.2500
GIT,microsoft/git-base-textvqa,FP32,ai2d,2,language_modeling,3,0,100,0.2400
GIT,microsoft/git-base-textvqa,FP32,scienceqa,2,language_modeling,4,0,100,0.2000
GIT,microsoft/git-base-textvqa,FP32,ai2d,2,language_modeling,4,0,100,0.3100
GIT,microsoft/git-base-vqav2,FP32,scienceqa,2,language_modeling,0,0,100,0.3000
GIT,microsoft/git-base-vqav2,FP32,scienceqa,2,language_modeling,1,0,100,0.2400
GIT,microsoft/git-base-vqav2,FP32,scienceqa,2,language_modeling,2,0,100,0.2800
GIT,microsoft/git-base-vqav2,FP32,scienceqa,2,language_modeling,3,0,100,0.3100
GIT,microsoft/git-base-vqav2,FP32,scienceqa,2,language_modeling,4,0,100,0.2400
GIT,microsoft/git-base-vqav2,FP32,vqa,2,language_modeling,0,0,100,0.6000
GIT,microsoft/git-base-vqav2,FP32,scienceqa,2,average_language_modeling,0,0,100,0.1600
GIT,microsoft/git-base-vqav2,FP32,scienceqa,2,average_language_modeling,1,0,100,0.2000
GIT,microsoft/git-base-vqav2,FP32,scienceqa,2,average_language_modeling,2,0,100,0.2200
GIT,microsoft/git-base-vqav2,FP32,scienceqa,2,average_language_modeling,3,0,100,0.1800
GIT,microsoft/git-base-vqav2,FP32,scienceqa,2,average_language_modeling,4,0,100,0.1300
GIT,microsoft/git-base-vqav2,FP32,ai2d,2,average_language_modeling,0,0,100,0.3000
GIT,microsoft/git-base-vqav2,FP32,ai2d,2,average_language_modeling,1,0,100,0.2600
GIT,microsoft/git-base-vqav2,FP32,ai2d,2,average_language_modeling,2,0,100,0.2200
GIT,microsoft/git-base-vqav2,FP32,ai2d,2,average_language_modeling,3,0,100,0.2600
GIT,microsoft/git-base-vqav2,FP32,ai2d,2,average_language_modeling,4,0,100,0.2700
GIT,microsoft/git-base-textvqa,FP32,scienceqa,2,average_language_modeling,0,0,100,0.1900
GIT,microsoft/git-base-textvqa,FP32,ai2d,2,average_language_modeling,0,0,100,0.3100
GIT,microsoft/git-base-textvqa,FP32,scienceqa,2,average_language_modeling,1,0,100,0.2000
GIT,microsoft/git-base-textvqa,FP32,ai2d,2,average_language_modeling,1,0,100,0.2800
GIT,microsoft/git-base-textvqa,FP32,scienceqa,2,average_language_modeling,2,0,100,0.2100
GIT,microsoft/git-base-textvqa,FP32,ai2d,2,average_language_modeling,2,0,100,0.2300
GIT,microsoft/git-base-textvqa,FP32,scienceqa,2,average_language_modeling,3,0,100,0.2200
GIT,microsoft/git-base-textvqa,FP32,ai2d,2,average_language_modeling,3,0,100,0.2800
GIT,microsoft/git-base-textvqa,FP32,scienceqa,2,average_language_modeling,4,0,100,0.2000
GIT,microsoft/git-base-textvqa,FP32,ai2d,2,average_language_modeling,4,0,100,0.2800
GIT,microsoft/git-base-textvqa,FP32,vqa,2,language_modeling,0,0,100,0.1900
GIT,microsoft/git-base-textvqa,FP32,vqa,2,average_language_modeling,0,0,100,0.1800
12 changes: 12 additions & 0 deletions mm_poe/results/language_modeling_old.csv
@@ -0,0 +1,12 @@
model_family,checkpoint,loading_precision,dataset,batch_size,method,seed,n_shot,sample,accuracy
BLIP2,Salesforce/blip2-opt-2.7b,INT8,vqa,16,language_modeling,0,0,100,0.0600
BLIP2,Salesforce/blip2-opt-2.7b,INT8,vqa,16,language_modeling,0,0,100,0.0600
BLIP2,Salesforce/blip2-opt-2.7b,INT4,vqa,4,language_modeling,0,0,100,0.4300
BLIP2,Salesforce/blip2-opt-2.7b,FP16,scienceqa,2,language_modeling,0,0,100,0.3800
BLIP2,Salesforce/blip2-opt-2.7b,FP16,scienceqa,2,language_modeling,0,0,100,0.3800
BLIP2,Salesforce/blip2-opt-2.7b,BF16,scienceqa,2,language_modeling,0,0,100,0.2200
BLIP2,Salesforce/blip2-flan-t5-xl,FP16,scienceqa,2,language_modeling,0,0,100,0.3600
BLIP2,Salesforce/blip2-opt-2.7b,FP16,ai2d,2,language_modeling,0,0,100,0.2300
PaliGemma,google/paligemma-3b-ft-ai2d-448,FP16,ai2d,2,language_modeling,0,0,100,0.2100
PaliGemma,google/paligemma-3b-ft-ai2d-448,FP16,ai2d,2,language_modeling,0,0,100,0.2100
GIT,microsoft/git-base-vqav2,FP32,ai2d,2,language_modeling,0,0,100,0.2100
File renamed without changes.
File renamed without changes.
23 changes: 0 additions & 23 deletions mm_poe/results/vision_language_modeling.csv

This file was deleted.

12 changes: 0 additions & 12 deletions mm_poe/results/vision_language_modeling1.csv

This file was deleted.

Binary file added paper/figures/17.png
Binary file added paper/figures/cli.png
23 changes: 23 additions & 0 deletions paper/paper.bib
@@ -285,3 +285,26 @@ @conj{Idefics2
version = {8b},
howpublished = {\url{https://huggingface.co/HuggingFaceM4/idefics2-8b}}
}

@InProceedings{VQA,
author = {Stanislaw Antol and Aishwarya Agrawal and Jiasen Lu and Margaret Mitchell and Dhruv Batra and C. Lawrence Zitnick and Devi Parikh},
title = {VQA: Visual Question Answering},
booktitle = {International Conference on Computer Vision (ICCV)},
year = {2015},
}

@article{Kembhavi2016ADI,
title={A Diagram is Worth a Dozen Images},
author={Aniruddha Kembhavi and Michael Salvato and Eric Kolve and Minjoon Seo and Hannaneh Hajishirzi and Ali Farhadi},
journal={ArXiv},
year={2016},
volume={abs/1603.07396},
url={https://api.semanticscholar.org/CorpusID:2682274}
}

@inproceedings{lu2022learn,
title={Learn to Explain: Multimodal Reasoning via Thought Chains for Science Question Answering},
author={Lu, Pan and Mishra, Swaroop and Xia, Tony and Qiu, Liang and Chang, Kai-Wei and Zhu, Song-Chun and Tafjord, Oyvind and Clark, Peter and Ashwin Kalyan},
booktitle={The 36th Conference on Neural Information Processing Systems (NeurIPS)},
year={2022}
}