This project aims to develop an efficient extractive question answering (QA) system using the Retrospective Reader model, with a focus on accuracy and real-time performance for applications such as search engines and virtual assistants. It is evaluated on the SQuAD v1.1 and SQuAD v2.0 benchmarks. SQuAD v2.0 also contains unanswerable questions, so it tests whether the model can tell answerable questions from unanswerable ones.
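For context, the original Retrospective Reader paper (Zhang et al.) decides whether a question is answerable in two passes. A sketchy-reading classifier produces an external no-answer score, intensive reading produces start/end span logits, and a rear verifier combines the two. The sketch below illustrates that decision rule on plain Python lists. It is not taken from `model.py`. The function names are hypothetical, and the equal weights (`beta1 = beta2 = 0.5`) and the threshold `delta = 0.0` are assumptions; the threshold is normally tuned on the dev set.

```python
from typing import Sequence


def best_span_score(start_logits: Sequence[float], end_logits: Sequence[float],
                    max_answer_len: int = 30) -> float:
    """Highest start+end logit sum over non-null spans (position 0 is [CLS])."""
    n = len(start_logits)
    best = float("-inf")
    for i in range(1, n):
        # End must not precede start, and the span is capped at max_answer_len tokens.
        for j in range(i, min(n, i + max_answer_len)):
            best = max(best, start_logits[i] + end_logits[j])
    return best


def is_unanswerable(start_logits: Sequence[float], end_logits: Sequence[float],
                    logit_na: float, logit_ans: float,
                    beta1: float = 0.5, beta2: float = 0.5,
                    delta: float = 0.0) -> bool:
    """Rear verification sketch: combine both readers' no-answer evidence."""
    # Intensive reading: null ([CLS]) span score minus best real span score.
    score_null = start_logits[0] + end_logits[0]
    score_diff = score_null - best_span_score(start_logits, end_logits)
    # Sketchy reading: no-answer logit minus has-answer logit.
    score_ext = logit_na - logit_ans
    # Assumed weights and threshold; a larger combined score favours "no answer".
    v = beta1 * score_diff + beta2 * score_ext
    return v > delta
```

When `is_unanswerable` returns `False`, the answer is the span that produced `best_span_score`; otherwise the prediction is the empty string.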
- Install the required packages:
pip install -r requirements.txt
The model was trained on H100-96 GPUs on the SoC cluster. Training times:
- SQuAD v1.1: 60 minutes
- SQuAD v2.0: 90 minutes
Notes:
- A model trained on SQuAD v2.0 can handle both answerable and unanswerable questions.
- If running on the SoC cluster, prefix each command with `srun`.
- The best hyperparameters are already configured:
  - Epochs: 2
  - Learning rate: 2e-5
python model.py --train --save_path "retro_reader_model.pth" --epochs 2 --lr 2e-5 --dataset "squad"
python model.py --train --save_path "retro_reader_model.pth" --epochs 1 --lr 2e-5 --dataset "squad_v2"
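If you want to script both training runs (for example, from a job-submission wrapper), the commands above can be assembled programmatically. This is a minimal sketch: `build_train_cmd` is a hypothetical helper that only reproduces the flags shown above, with the optional `srun` prefix from the cluster note. It does not know about any other options `model.py` may accept.

```python
import shlex
from typing import List


def build_train_cmd(dataset: str, epochs: int, lr: str = "2e-5",
                    save_path: str = "retro_reader_model.pth",
                    use_srun: bool = False) -> List[str]:
    """Return the model.py training command from this README as an argv list."""
    cmd = ["python", "model.py", "--train",
           "--save_path", save_path,
           "--epochs", str(epochs),
           "--lr", lr,  # kept as a string so it prints exactly as "2e-5"
           "--dataset", dataset]
    # On the SoC cluster the command is prefixed with srun.
    return ["srun", *cmd] if use_srun else cmd


if __name__ == "__main__":
    # Epoch counts copied from the two training commands above.
    for dataset, epochs in [("squad", 2), ("squad_v2", 1)]:
        print(shlex.join(build_train_cmd(dataset, epochs, use_srun=True)))
```

To launch a run instead of printing it, pass the list to `subprocess.run(cmd, check=True)`. Using an argv list rather than a single shell string avoids quoting problems with paths.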
The testing process generates a `predictions.json` file, which can be evaluated using the evaluation script.
python model.py --test --model_path "retro_reader_model.pth" --dataset "squad"
python model.py --test --model_path "retro_reader_model.pth" --dataset "squad_v2"
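The official SQuAD evaluation scripts read `predictions.json` as a single JSON object that maps each question id to its predicted answer string. For SQuAD v2.0, an empty string means "no answer". The sketch below shows how to produce and sanity-check such a file. It assumes the dev set uses the standard SQuAD JSON layout (`data` → `paragraphs` → `qas` → `id`). The helper names are hypothetical and not part of `model.py`.

```python
import json
from typing import Dict, Optional, Set


def to_prediction_dict(answers: Dict[str, Optional[str]]) -> Dict[str, str]:
    """Map question id -> answer text; None (no answer) becomes ""."""
    return {qid: ("" if text is None else text) for qid, text in answers.items()}


def missing_ids(dev_data: dict, preds: Dict[str, str]) -> Set[str]:
    """Return question ids from a SQuAD-format dev file that lack a prediction."""
    all_ids = {qa["id"]
               for article in dev_data["data"]
               for para in article["paragraphs"]
               for qa in para["qas"]}
    return all_ids - set(preds)


def write_predictions(preds: Dict[str, str], path: str = "predictions.json") -> None:
    """Write predictions in the {id: text} layout the evaluation script reads."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(preds, f, ensure_ascii=False)
```

Checking `missing_ids` before evaluating is useful because questions without a prediction are counted as wrong.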
The inference process will generate an answer given a question and context.
python model.py --inference --question "What is the capital of France?" --context "Paris is the capital of France. It has many great works of architecture, including the Eiffel Tower." --model_path "retro_reader_model.pth" --dataset "squad"
python model.py --inference --question "What is the capital of France?" --context "Paris is the capital of France. It has many great works of architecture, including the Eiffel Tower." --model_path "retro_reader_model.pth" --dataset "squad_v2"
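At inference time, an extractive reader scores every token as a possible answer start and answer end. It then returns the highest-scoring valid span inside the context. The sketch below shows that standard decoding step on plain Python lists. The function name, the half-open `context_range` convention, and the 30-token length cap are assumptions. `model.py` may differ, for example in how it maps tokens back to character offsets in the context.

```python
from typing import List, Tuple


def best_span(start_logits: List[float], end_logits: List[float],
              context_range: Tuple[int, int],
              max_answer_len: int = 30) -> Tuple[int, int]:
    """Pick (start, end) maximising start_logit + end_logit inside the context.

    context_range is a half-open [lo, hi) range of token positions that belong
    to the context, so question tokens and special tokens are never selected.
    """
    lo, hi = context_range
    best, best_score = (lo, lo), float("-inf")
    for i in range(lo, hi):
        # Only end positions at or after the start, within the length cap.
        for j in range(i, min(hi, i + max_answer_len)):
            score = start_logits[i] + end_logits[j]
            if score > best_score:
                best, best_score = (i, j), score
    return best
```

For example, with tokens `["[CLS]", "capital", "?", "[SEP]", "paris", "is", "the", "capital", "[SEP]"]` and a context range of `(4, 8)`, the answer text is `" ".join(tokens[i:j + 1])` for the returned `(i, j)`.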
- Generate the `predictions.json` file (this can be done using the test functionality of the `model.py` script).
- Evaluate the model using the evaluation script.
python evaluate-v2.0.py dev-v1.1.json predictions.json
python evaluate-v2.0.py dev-v2.0.json predictions.json
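The evaluation script reports Exact Match (EM) and token-level F1. The sketch below mirrors the answer normalization and scoring used by the official SQuAD v2.0 script: it lowercases, strips punctuation and the articles a/an/the, and then compares tokens. It is meant to help interpret the numbers and does not replace `evaluate-v2.0.py`. The function names are illustrative.

```python
import collections
import re
import string
from typing import List, Tuple

_PUNCT = set(string.punctuation)


def normalize_answer(s: str) -> str:
    """Lowercase, drop punctuation and articles, collapse whitespace."""
    s = "".join(ch for ch in s.lower() if ch not in _PUNCT)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def exact_match(gold: str, pred: str) -> int:
    return int(normalize_answer(gold) == normalize_answer(pred))


def f1_score(gold: str, pred: str) -> float:
    gold_toks = normalize_answer(gold).split()
    pred_toks = normalize_answer(pred).split()
    # Unanswerable case: both empty is a match, exactly one empty is a miss.
    if not gold_toks or not pred_toks:
        return float(gold_toks == pred_toks)
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_toks)
    recall = num_same / len(gold_toks)
    return 2 * precision * recall / (precision + recall)


def score_question(golds: List[str], pred: str) -> Tuple[int, float]:
    """Best EM and F1 over all gold answers; no gold answers means unanswerable."""
    golds = [g for g in golds if normalize_answer(g)] or [""]
    return (max(exact_match(g, pred) for g in golds),
            max(f1_score(g, pred) for g in golds))
```

The reported EM and F1 are these per-question scores averaged over the dev set and multiplied by 100.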