diff --git a/lmms_eval/models/internvl2.py b/lmms_eval/models/internvl2.py
index 5f4365d01..730493d9e 100644
--- a/lmms_eval/models/internvl2.py
+++ b/lmms_eval/models/internvl2.py
@@ -209,6 +209,14 @@ def generate_until(self, requests) -> List[str]:
                 if k not in gen_kwargs:
                     gen_kwargs[k] = v
 
+            pop_keys = []
+            for k, v in gen_kwargs.items():
+                if k not in DEFAULT_GEN_KWARGS:
+                    pop_keys.append(k)
+
+            for k in pop_keys:
+                gen_kwargs.pop(k)
+
             visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
             visuals = self.flatten(visuals)
             if self.modality == "image":
diff --git a/lmms_eval/tasks/muirbench/muirbench.yaml b/lmms_eval/tasks/muirbench/muirbench.yaml
new file mode 100644
index 000000000..43b8ab7cc
--- /dev/null
+++ b/lmms_eval/tasks/muirbench/muirbench.yaml
@@ -0,0 +1,41 @@
+
+dataset_path: MUIRBENCH/MUIRBENCH
+task: "muirbench"
+dataset_kwargs:
+  token: True
+test_split: test
+output_type: generate_until
+doc_to_visual: !function utils.muir_doc_to_visual
+doc_to_text: !function utils.muir_doc_to_text
+doc_to_target: !function utils.muir_doc_to_target
+process_results: !function utils.muir_process_results
+
+model_specific_prompt_kwargs:
+  default:
+    pre_prompt: ""
+    post_prompt: "\nAnswer with the option's letter from the given choices directly."
+
+
+generation_kwargs:
+  max_new_tokens: 16
+  temperature: 0
+  do_sample: False
+
+filter_list:
+  - name: "flexible-extract"
+    filter:
+      - function: !function utils.MultiChoiceRegexFilter
+        group_select: 0
+        ignore_case: true
+        ignore_punctuation: true
+        regex_pattern: "([A-Z])\\."
+
+metric_list:
+  - metric: muirbench_score_overall
+    aggregation: !function utils.muir_aggregation
+    higher_is_better: true
+    ignore_case: true
+    ignore_punctuation: true
+
+metadata:
+  - version: 0.0
diff --git a/lmms_eval/tasks/muirbench/utils.py b/lmms_eval/tasks/muirbench/utils.py
new file mode 100644
index 000000000..2924edac4
--- /dev/null
+++ b/lmms_eval/tasks/muirbench/utils.py
@@ -0,0 +1,117 @@
+
+from lmms_eval.filters.extraction import ExtendedRegexFilter
+from lmms_eval.filters.transformation import MapFilter
+import re
+import pandas as pd
+
+
+def muir_doc_to_text(doc, model_specific_prompt_kwargs=None):
+    question, choices = doc["question"], doc["options"]
+    len_choices = len(choices)
+    post_prompt = model_specific_prompt_kwargs["post_prompt"]
+    pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
+    options = [chr(ord("A") + i) for i in range(len_choices)]
+    choices_str = "\n".join([f"{option}. {choice}" for option, choice in zip(options, choices)])
+    return f"{pre_prompt}{question}\n{choices_str}{post_prompt}"
+
+
+def muir_doc_to_visual(doc):
+    image_list = [image.convert("RGB") for image in doc["image_list"]]
+    return image_list
+
+
+def muir_doc_to_target(doc):
+    return doc["answer"]
+
+
+def muir_process_results(doc, result):
+    pred = result[0]
+    task = doc["task"]
+    idx = doc["idx"]
+    image_relation = doc["image_relation"]
+    answer = doc["answer"]
+    image_type = doc["image_type"]
+
+    data_dict = {
+        "pred" : pred,
+        "task" : task,
+        "idx" : idx,
+        "image_relation" : image_relation,
+        "answer" : answer,
+        "image_type" : image_type,
+    }
+
+    return {"muirbench_score_overall" : data_dict}
+
+
+def muir_aggregation(results):
+    task_num = {}
+    score = 0
+    task_score = {}
+    for result in results:
+        if result["task"] not in task_score:
+            task_score[result["task"]] = 0
+
+        if result["task"] not in task_num:
+            task_num[result["task"]] = 0
+
+        if result["pred"].lower().strip() == result["answer"].lower().strip():
+            task_score[result["task"]] += 1
+            score += 1
+        task_num[result["task"]] += 1
+
+    score = score / len(results)
+
+    task_score = {k : v / task_num[k] for k, v in task_score.items()}
+
+    print("=" * 50)
+    for k, v in task_score.items():
+        print(f"{k} : {v:.2f}")
+    print("=" * 50)
+
+    return score
+
+
+
+
+class MultiChoiceRegexFilter(ExtendedRegexFilter):
+    def __init__(self, *args, **kwargs):
+        """
+        regex_pattern: The basic regex pattern to use. If it fails to match, we will use the customized match procedure
+                        - step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
+                        - step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
+        group_select: Selects the (group_select)th match from the findall result.
+        ignore_case: Ignores the case during step 1 matching
+        ignore_punctuation: Remove the punctuation during step 1 matching
+        regexes_to_ignore: Remove these regexes during step 1 matching
+        """
+        super().__init__(*args, **kwargs)
+
+    def apply(self, resps, docs):
+        # here, we assume we have a list, in which each element is
+        # a list of model responses for some particular input/target pair.
+        # so we process each of these (same input/target response sets)
+        # independently (and keep them a list.)
+
+        filtered_resps = []
+
+        for r, doc in zip(resps, docs):
+            # Regex to directly extract the option letter from the model response
+            option_letter_regex = re.compile(r"^\s*([A-Z])\.")
+
+            # Process each response
+            filtered = []
+            for resp in r:
+                # Try to match the option letter at the start of the response
+                match = option_letter_regex.match(resp)
+                if match:
+                    # If a match is found, append the matched letter
+                    filtered.append(match.group(1))
+                else:
+                    # If no match, return the original response
+                    filtered.append(resp)
+
+            # Assuming we need the first response that matches or the original response
+            filtered_resps.append(filtered[0])
+
+        return filtered_resps