# Copyright (c) OpenMMLab. All rights reserved.
#
# Provenance: mmpretrain/datasets/iconqa.py as added by commit
# 4f2f3752d970e0425a5f21defab91bc366265eb8, "Support IconQA dataset. (#1670)"
# (Yike Yuan, Tue, 1 Aug 2023). The same commit registers the dataset in
# mmpretrain/datasets/__init__.py by adding ``from .iconqa import IconQA``
# (between the GQA and NoCaps imports) and appending 'IconQA' to ``__all__``.
from typing import List

import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import list_dir_or_file
from mmengine.utils import check_file_exist

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class IconQA(BaseDataset):
    """IconQA: A benchmark for abstract diagram understanding
    and visual language reasoning.

    Every sample is a sub-directory of ``data_prefix`` that contains a
    ``data.json`` annotation file and an ``image.png`` image.

    Args:
        data_root (str): The root directory for ``data_prefix``.
        data_prefix (str): The directory of the specific task and split.
            e.g. ``iconqa/val/choose_text/``.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self, data_root: str, data_prefix: str, **kwargs):
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Each sub-directory of ``self.data_prefix['img_path']`` is treated as
        one sample. Its ``data.json`` is loaded as the sample dict, which is
        then extended with two keys:

        - ``gt_answer``: the entry of ``choices`` selected by ``answer``.
        - ``img_path``: the path of the sample's ``image.png``.

        Returns:
            List[dict]: One dict per sample, ordered by sample directory
            name.
        """
        img_root = self.data_prefix['img_path']

        # Directory listing order depends on the storage backend, so sort
        # the sample ids to give the dataset a reproducible order.
        sample_ids = sorted(list_dir_or_file(img_root, list_file=False))

        data_list = []
        for sample_id in sample_ids:
            # data json
            # {
            #     "question": "How likely is it that you will pick a black one?",  # noqa: E501
            #     "choices": [
            #         "certain",
            #         "unlikely",
            #         "impossible",
            #         "probable"
            #     ],
            #     "answer": 2,
            #     "ques_type": "choose_txt",
            #     "grade": "grade1",
            #     "label": "S2"
            # }
            data_info = mmengine.load(
                mmengine.join_path(img_root, sample_id, 'data.json'))

            # ``answer`` is an index into ``choices``; keep the chosen text
            # alongside the raw annotation.
            answer_idx = int(data_info['answer'])
            data_info['gt_answer'] = data_info['choices'][answer_idx]

            # Fail early if the image that belongs to the sample is missing.
            img_path = mmengine.join_path(img_root, sample_id, 'image.png')
            check_file_exist(img_path)
            data_info['img_path'] = img_path

            data_list.append(data_info)

        return data_list