diff --git a/pyproject.toml b/pyproject.toml
index 1d6c11ec..7ff2fca8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -88,7 +88,8 @@ quality = ["ruff==v0.2.2","pre-commit"]
 tests = ["pytest==7.4.0"]
 dev = ["lighteval[accelerate,quality,tests]"]
 extended_tasks = [
-    "langdetect", #ifeval
+    "langdetect", # ifeval
+    "openai", # mt-bench
 ]
 
 [project.urls]
diff --git a/src/lighteval/evaluator.py b/src/lighteval/evaluator.py
index d74661c7..5070261c 100644
--- a/src/lighteval/evaluator.py
+++ b/src/lighteval/evaluator.py
@@ -88,6 +88,8 @@ def evaluate(  # noqa: C901
         full_resps = lm.greedy_until_with_logits(requests, override_bs=override_bs)
     elif request_type == RequestType.LOGLIKELIHOOD_ROLLING:
         full_resps = lm.loglikelihood_rolling(requests, override_bs=override_bs)
+    elif request_type == RequestType.GREEDY_UNTIL_MULTI_TURN:
+        full_resps = lm.greedy_until_multi_turn(requests, override_bs=override_bs)
     else:
         raise NotImplementedError(f"Request type {request_type} not supported")
 
@@ -115,8 +117,22 @@ def evaluate(  # noqa: C901
             # using a deep copy here because process results pops from the model responses
             metrics = task.process_results(doc, copy.deepcopy(model_responses))
 
+            # Remove the user_prompt and judgement from the metrics in the case of an llm-as-judge metric
+            if "user_prompt" in metrics:
+                user_prompt = metrics["user_prompt"]
+                del metrics["user_prompt"]
+            else:
+                user_prompt = None
+            if "judgement" in metrics:
+                judgement = metrics["judgement"]
+                del metrics["judgement"]
+            else:
+                judgement = None
+
             evaluation_tracker.metrics_logger.log(task_example_id.task_name, metrics)
-            evaluation_tracker.details_logger.log(task_example_id.task_name, task, doc, model_responses, metrics)
+            evaluation_tracker.details_logger.log(
+                task_example_id.task_name, task, doc, model_responses, metrics, (user_prompt, judgement)
+            )
 
     return evaluation_tracker
diff --git a/src/lighteval/few_shot_manager.py b/src/lighteval/few_shot_manager.py
index 984685f3..081703a8 100644
--- a/src/lighteval/few_shot_manager.py
+++ b/src/lighteval/few_shot_manager.py
@@ -27,7 +27,7 @@ from itertools import cycle
 from typing import TYPE_CHECKING, Optional
 
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, PreTrainedTokenizer
 
 from lighteval.logging.hierarchical_logger import hlog_warn
 from lighteval.tasks.requests import Doc
@@ -219,6 +219,46 @@ def get_examples(
         )
         return instruction + labeled_examples + example
 
+    def create_multi_turn_contexts(
+        self, doc: Doc, use_chat_template: bool, system_prompt: Optional[str], tokenizer: PreTrainedTokenizer
+    ) -> tuple[list[str], int]:
+        """Creates one context per turn of a multi-turn task.
+        Multi-turn tasks require chat templating.
+
+        Args:
+            doc (Doc): Formatted document.
+            use_chat_template (bool): Whether or not to use the chat template. Will fail if False.
+            system_prompt (Optional[str]): The system prompt to use.
+            tokenizer (PreTrainedTokenizer): The tokenizer used for the chat template.
+
+        Raises:
+            ValueError: If use_chat_template is set to False.
+
+        Returns:
+            tuple[list[str], int]: the rendered context for every turn, plus the number of effective few shots (always 0 here).
+        """
+        if not use_chat_template:
+            raise ValueError("You need to use the chat template to create multi turn contexts")
+
+        role_content_list = []
+        if system_prompt is not None:
+            role_content_list.append({"role": "system", "content": system_prompt})
+
+        for i in doc.specific["multi_turn_queries"]:
+            role_content_list.append({"role": "user", "content": i})
+            role_content_list.append({"role": "assistant", "content": "{model_response}"})
+        role_content_list.pop(-1)
+
+        contexts = []
+        offset = 2 if system_prompt is not None else 1
+        for i in range(0, len(role_content_list), offset + 1):
+            c = tokenizer.apply_chat_template(
+                role_content_list[: i + offset], add_generation_prompt=True, tokenize=False, add_special_tokens=False
+            )
+            contexts.append(c)
+
+        return contexts, 0
+
     def fewshot_context(
         self,
         task: "LightevalTask",
diff --git a/src/lighteval/logging/info_loggers.py b/src/lighteval/logging/info_loggers.py
index 0caabaa8..b11c124c 100644
--- a/src/lighteval/logging/info_loggers.py
+++ b/src/lighteval/logging/info_loggers.py
@@ -24,7 +24,7 @@ import os
 import time
 from dataclasses import asdict, dataclass, field
-from typing import Union
+from typing import Optional, Union
 
 import git
 import numpy as np
@@ -205,6 +205,9 @@ class Detail:
         choices: list = field(default_factory=list)
         gold_index: list = field(default_factory=list)
         metrics: dict = field(default_factory=dict)
+        judgement_prompt: str = None
+        judgement: str = None
+        specifics: dict = field(default_factory=dict)
 
     @dataclass
     class CompiledDetail:
@@ -302,7 +305,15 @@ class CompiledHash:
     compiled_details: dict[str, CompiledDetail] = collections.defaultdict(CompiledDetail)
     compiled_details_over_all_tasks: CompiledDetailOverAllTasks = CompiledDetailOverAllTasks()
 
-    def log(self, task_name: str, task: LightevalTask, doc: Doc, outputs: list[ModelReturn], metrics: dict) -> None:
+    def log(
+        self,
+        task_name: str,
+        task: LightevalTask,
+        doc: Doc,
+        outputs: list[ModelReturn],
+        metrics: dict,
+        llm_as_prompt_judgement: Optional[tuple[str, str]] = None,
+    ) -> None:
         """Stores the relevant information for one sample of one task to the total list of samples stored in the DetailsLogger.
 
         Args:
@@ -311,6 +322,8 @@ def log(self, task_name: str, task: LightevalTask, doc: Doc, outputs: list[Model
             doc (Doc): Current sample that we want to store.
             outputs (list[ModelReturn]): Model outputs for the current sample
             metrics (_type_): Model scores for said sample on the current task's metrics.
+            llm_as_prompt_judgement (tuple[str, str]): Tuple containing the
+                prompt passed to the judge and the judgement for the current sample when using an llm-as-judge metric.
         """
         detail = self.Detail()
         detail.example = doc.query
@@ -354,6 +367,11 @@ def log(self, task_name: str, task: LightevalTask, doc: Doc, outputs: list[Model
             detail.choices = doc.choices
             detail.gold_index = as_list(doc.gold_index)
             pred_saved = True
+        if task.has_metric_category[MetricCategory.GENERATIVE_MULTI_TURN]:
+            pred_saved = True
+            detail.judgement_prompt = llm_as_prompt_judgement[0]
+            detail.judgement = llm_as_prompt_judgement[1]
+            detail.specifics = doc.specific
         if not pred_saved:
             raise NotImplementedError(
                 "No metric prediction saved."
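To make the multi-turn context construction in `create_multi_turn_contexts` above easier to follow, here is a small self-contained sketch of the same loop. `render_chat` is only a stand-in for `tokenizer.apply_chat_template`; the exact rendered strings depend on the model's chat template.

```python
# A self-contained sketch of the per-turn context construction above.
# Assumption: `render_chat` stands in for tokenizer.apply_chat_template; the real
# rendered strings depend on the model's chat template.
def render_chat(messages, add_generation_prompt=True):
    text = "".join(f"<|{m['role']}|>\n{m['content']}\n" for m in messages)
    return text + ("<|assistant|>\n" if add_generation_prompt else "")


def build_multi_turn_contexts(queries, system_prompt=None):
    roles = []
    if system_prompt is not None:
        roles.append({"role": "system", "content": system_prompt})
    for query in queries:
        roles.append({"role": "user", "content": query})
        roles.append({"role": "assistant", "content": "{model_response}"})
    roles.pop(-1)  # the last assistant slot is what the model will generate next

    contexts = []
    offset = 2 if system_prompt is not None else 1
    for i in range(0, len(roles), offset + 1):
        contexts.append(render_chat(roles[: i + offset]))
    return contexts


contexts = build_multi_turn_contexts(["Write a haiku.", "Now translate it."], system_prompt="Be terse.")
# contexts[0] ends right after the first user turn; contexts[1] still contains the
# literal "{model_response}" placeholder, which greedy_until_multi_turn fills in with
# the decoded answer to turn 1 before generating turn 2.
print(contexts[1])
```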
@@ -364,7 +382,7 @@ def log(self, task_name: str, task: LightevalTask, doc: Doc, outputs: list[Model hash = self.Hash() hash.example = xxhash.xxh64(doc.query).hexdigest() - hash.full_prompt = xxhash.xxh64(doc.ctx).hexdigest() + hash.full_prompt = xxhash.xxh64(str(doc.ctx)).hexdigest() hash.input_tokens = xxhash.xxh64(str([o.input_tokens for o in outputs])).hexdigest() hash.cont_tokens = xxhash.xxh64(str([o.generated_tokens for o in outputs])).hexdigest() self.hashes[task_name].append(hash) diff --git a/src/lighteval/metrics/__init__.py b/src/lighteval/metrics/__init__.py index f7e52035..3dfd0ca9 100644 --- a/src/lighteval/metrics/__init__.py +++ b/src/lighteval/metrics/__init__.py @@ -146,3 +146,14 @@ def apply_multichoice_metric_one_token(results: list[ModelReturn], formatted_doc ) return results, outputs + + +def apply_generative_multi_turn_metric(results: list[ModelReturn], formatted_doc: Doc, metrics: list[str]): + outputs = {} + predictions = results.pop(0).result + + for metric in metrics: + if Metrics[metric].value.category == MetricCategory.GENERATIVE_MULTI_TURN: + outputs.update(Metrics[metric].value.compute(predictions=predictions, formatted_doc=formatted_doc)) + + return results, outputs diff --git a/src/lighteval/metrics/utils.py b/src/lighteval/metrics/utils.py index d0a410f2..eb1585e6 100644 --- a/src/lighteval/metrics/utils.py +++ b/src/lighteval/metrics/utils.py @@ -28,6 +28,7 @@ class MetricCategory(Enum): TARGET_PERPLEXITY = auto() PERPLEXITY = auto() GENERATIVE = auto() + GENERATIVE_MULTI_TURN = auto() GENERATIVE_LOGPROB = auto() MULTICHOICE = auto() MULTICHOICE_ONE_TOKEN = auto() diff --git a/src/lighteval/models/abstract_model.py b/src/lighteval/models/abstract_model.py index be88a6b2..ccc49146 100644 --- a/src/lighteval/models/abstract_model.py +++ b/src/lighteval/models/abstract_model.py @@ -27,8 +27,14 @@ from transformers import BatchEncoding from lighteval.models.model_config import EnvConfig -from lighteval.models.model_output import GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn +from lighteval.models.model_output import ( + GenerateMultiTurnReturn, + GenerateReturn, + LoglikelihoodReturn, + LoglikelihoodSingleTokenReturn, +) from lighteval.tasks.requests import ( + GreedyUntilMultiTurnRequest, GreedyUntilRequest, GreedyUntilWithLogitsRequest, LoglikelihoodRequest, @@ -102,6 +108,12 @@ def greedy_until_with_logits( returns_logits=True, ) + def greedy_until_multi_turn( # noqa: C901 + self, requests: list[GreedyUntilMultiTurnRequest], override_bs: Optional[int] = None + ) -> GenerateMultiTurnReturn: + """Generates responses using a greedy decoding strategy until certain ending conditions are met.""" + return NotImplemented + @abstractmethod def greedy_until( self, diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py index b551c08f..e5545c36 100644 --- a/src/lighteval/models/base_model.py +++ b/src/lighteval/models/base_model.py @@ -35,9 +35,16 @@ from lighteval.logging.hierarchical_logger import hlog, hlog_err, hlog_warn from lighteval.models.abstract_model import LightevalModel from lighteval.models.model_config import BaseModelConfig, EnvConfig -from lighteval.models.model_output import Batch, GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn -from lighteval.models.utils import _get_dtype, _get_precision, _simplify_name +from lighteval.models.model_output import ( + Batch, + GenerateMultiTurnReturn, + GenerateReturn, + LoglikelihoodReturn, + LoglikelihoodSingleTokenReturn, +) +from 
lighteval.models.utils import _get_dtype, _get_precision, _simplify_name, batched from lighteval.tasks.requests import ( + GreedyUntilMultiTurnRequest, GreedyUntilRequest, GreedyUntilWithLogitsRequest, LoglikelihoodRequest, @@ -345,6 +352,137 @@ def greedy_until_with_logits( override_bs=override_bs, ) + def greedy_until_multi_turn( # noqa: C901 + self, requests: list[GreedyUntilMultiTurnRequest], override_bs: Optional[int] = None + ) -> GenerateMultiTurnReturn: + for request in requests: + request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token] + request.tokenized_context = self.tok_encode(request.context)["input_ids"] + + results = [] + + dataset = GenerativeTaskDataset(requests=requests, dataset_splits=1) + dataloader = DataLoader(dataset, batch_size=1, collate_fn=lambda batch: batch) + + if self.accelerator: + dataloader = self.accelerator.prepare(dataloader) + + hlog_warn("Running greedy multi turn generation, the batch size is set to 1 for this task.") + + for request_batch in tqdm( + dataloader, desc="Greedy Multi Turn generation", position=1, leave=False, disable=self.disable_tqdm + ): + request = request_batch[0] + stop_tokens = request.stop_sequence + max_generated_tokens = request.generation_size + context = request.context[0] + max_context_size_allowed = self.max_length - max_generated_tokens + + model_inputs = self.tokenizer( + context, + padding=True, + truncation=True, + return_tensors="pt", + max_length=max_context_size_allowed, + add_special_tokens=self.add_special_tokens, + ).to(self.device) + + stopping_criteria = transformers.StoppingCriteriaList( + [ + *[ + MultiTokenEOSCriteria( + sequence, self.tokenizer, input_ids_shape=model_inputs["input_ids"].shape + ) + for sequence in stop_tokens + ], + ] + ) + model_outputs = self.model.generate( + **model_inputs, + max_new_tokens=max_generated_tokens, + stopping_criteria=stopping_criteria, + do_sample=False, + pad_token_id=self.tokenizer.pad_token_id + if self.tokenizer.pad_token_id + else self.tokenizer.eos_token_id, + ) + model_outputs = model_outputs[0, model_inputs["input_ids"].size(1) :] + model_generations = [model_outputs] + decoded_generation = self.tokenizer.decode(model_outputs) + for term in stop_tokens: + decoded_generation = decoded_generation.split(term)[0] + + input_tokens = [model_inputs["input_ids"]] + + for i, multi_turn_context in enumerate(request.context[1:]): + multi_turn_context = multi_turn_context.format(model_response=decoded_generation) + + model_inputs = self.tokenizer( + multi_turn_context, + padding=True, + truncation=True, + return_tensors="pt", + max_length=max_context_size_allowed, + add_special_tokens=self.add_special_tokens, + ).to(self.device) + + stopping_criteria = transformers.StoppingCriteriaList( + [ + *[ + MultiTokenEOSCriteria( + sequence, self.tokenizer, input_ids_shape=model_inputs["input_ids"].shape + ) + for sequence in stop_tokens + ], + ] + ) + + model_outputs = self.model.generate( + input_ids=model_inputs["input_ids"], + attention_mask=model_inputs["attention_mask"], + max_new_tokens=max_generated_tokens, + stopping_criteria=stopping_criteria, + do_sample=False, + pad_token_id=self.tokenizer.pad_token_id + if self.tokenizer.pad_token_id + else self.tokenizer.eos_token_id, + ) + model_outputs = model_outputs[0, model_inputs["input_ids"].size(1) :] + model_generations.append(model_outputs) + decoded_generation = self.tokenizer.decode(model_outputs, skip_special_tokens=True) + input_tokens.append(model_inputs["input_ids"]) + + for term in stop_tokens: 
+ decoded_generation = decoded_generation.split(term)[0] + + if self.accelerator: + padding_size = max(gen.shape[0] for gen in model_generations) + for i, gen in enumerate(model_generations): + model_generations[i] = F.pad( + gen, (0, padding_size - gen.shape[0]), value=self.tokenizer.pad_token_id + ) + model_generations = torch.stack(model_generations, dim=0) + model_generations, lengths = self.pad_and_gather(model_generations, drop_last_samples=False) + + model_answers = [] + for generation, _ in zip(model_generations, lengths): + generation = generation.cpu().tolist() + decoded = self.tokenizer.decode(generation, skip_special_tokens=True) + model_answers.append(decoded) + + for answers in batched(model_answers, len(request.context)): + results.append( + GenerateMultiTurnReturn( + result=answers, + input_tokens=[], + generated_tokens=[], + truncated_tokens_count=0, + padded_tokens_count=0, + ) + ) + + return results + def greedy_until( self, requests: list[GreedyUntilRequest], @@ -753,9 +891,20 @@ def prepare_batch_logprob( padded=padded, ) - def pad_and_gather(self, output_tensor: torch.Tensor) -> torch.Tensor: - """Gather together tensors of (possibly) various size spread on separate GPUs (first exchange the lengths and then pad and gather)""" - # Create a tensor of size batch_size, [output_length] * batch_size, for each each process + def pad_and_gather(self, output_tensor: torch.Tensor, drop_last_samples: bool = True) -> torch.Tensor: + """ + Pads the `output_tensor` to the maximum length and gathers the lengths across processes. + + Args: + output_tensor (torch.Tensor): The output tensor to be padded. + drop_last_samples (bool, optional): Whether to drop the last samples during gathering. + Last samples are dropped when the number of samples is not divisible by the number of processes. + Defaults to True. + + Returns: + torch.Tensor: The padded output tensor and the gathered length tensor. + """ + # Create a tensor of size batch_size, [output_length] * batch_size, for each process length_tensor = torch.tensor([output_tensor.shape[1]] * output_tensor.shape[0], device=self.device) if self.accelerator is not None: # Gather all the lengths, we end up with a tensor of size num_processes [output_length_1, output_length_2, ...] 
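The regrouping at the end of `greedy_until_multi_turn` above relies on the `batched` helper added to `src/lighteval/models/utils.py` further down in this diff. A quick, self-contained illustration of its behaviour:

```python
# Illustration of the `batched` helper (added to src/lighteval/models/utils.py below):
# it lazily chunks an iterable into tuples of at most n items, which is how the flat
# list of gathered answers is folded back into one tuple of turn answers per request.
from itertools import islice


def batched(iterable, n):
    # batched('ABCDEFG', 3) -> ('A', 'B', 'C') ('D', 'E', 'F') ('G',)
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch


turns_per_request = 2
flat_answers = ["req1 turn1", "req1 turn2", "req2 turn1", "req2 turn2"]
print(list(batched(flat_answers, turns_per_request)))
# [('req1 turn1', 'req1 turn2'), ('req2 turn1', 'req2 turn2')]
```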
@@ -766,7 +915,10 @@ def pad_and_gather(self, output_tensor: torch.Tensor) -> torch.Tensor: output_tensor, (0, max_length - output_tensor.shape[1], 0, 0), value=self.tokenizer.pad_token_id ) if self.accelerator: - output_tensor = self.accelerator.gather_for_metrics(output_tensor) + if drop_last_samples: + output_tensor = self.accelerator.gather_for_metrics(output_tensor) + else: + output_tensor = self.accelerator.gather(output_tensor) return output_tensor, length_tensor def loglikelihood_single_token( @@ -891,10 +1043,15 @@ def __init__( self, sequence: str, tokenizer: transformers.PreTrainedTokenizer, - batch: Batch, + batch: Batch = None, + input_ids_shape: Tuple[int, int] = None, ): - initial_decoder_input_length = batch.input_ids.shape[1] - batch_size = batch.input_ids.shape[0] + if batch is not None: + initial_decoder_input_length = batch.input_ids.shape[1] + batch_size = batch.input_ids.shape[0] + else: + initial_decoder_input_length = input_ids_shape[1] + batch_size = input_ids_shape[0] self.initial_decoder_input_length = initial_decoder_input_length self.done_tracker = [False] * batch_size diff --git a/src/lighteval/models/model_output.py b/src/lighteval/models/model_output.py index fae528a8..51027858 100644 --- a/src/lighteval/models/model_output.py +++ b/src/lighteval/models/model_output.py @@ -66,6 +66,14 @@ def get_result_for_eval(self): return self.result if self.logits is None else (self.result, self.logits) +@dataclass +class GenerateMultiTurnReturn(ModelReturn): + result: list[str] = field(default_factory=list) + + def get_result_for_eval(self): + return self.result + + @dataclass class Batch: input_ids: torch.Tensor diff --git a/src/lighteval/models/utils.py b/src/lighteval/models/utils.py index 9c66f29f..ba968151 100644 --- a/src/lighteval/models/utils.py +++ b/src/lighteval/models/utils.py @@ -21,6 +21,7 @@ # SOFTWARE. import os +from itertools import islice from typing import TYPE_CHECKING, Optional, Union import torch @@ -113,3 +114,12 @@ def _get_model_sha(repo_id: str, revision: str): return model_info.sha except Exception: return "" + + +def batched(iterable, n): + # batched('ABCDEFG', 3) → ABC DEF G + if n < 1: + raise ValueError("n must be at least one") + it = iter(iterable) + while batch := tuple(islice(it, n)): + yield batch diff --git a/src/lighteval/tasks/extended/__init__.py b/src/lighteval/tasks/extended/__init__.py index 201c8c4d..81919c0a 100644 --- a/src/lighteval/tasks/extended/__init__.py +++ b/src/lighteval/tasks/extended/__init__.py @@ -25,9 +25,10 @@ if can_load_extended_tasks(): import lighteval.tasks.extended.ifeval.main as ifeval + import lighteval.tasks.extended.mt_bench.main as mt_bench import lighteval.tasks.extended.tiny_benchmarks.main as tiny_benchmarks - AVAILABLE_EXTENDED_TASKS_MODULES = [ifeval, tiny_benchmarks] + AVAILABLE_EXTENDED_TASKS_MODULES = [ifeval, tiny_benchmarks, mt_bench] else: AVAILABLE_EXTENDED_TASKS_MODULES = [] diff --git a/src/lighteval/tasks/extended/mt_bench/judge_prompts.jsonl b/src/lighteval/tasks/extended/mt_bench/judge_prompts.jsonl new file mode 100644 index 00000000..86854fff --- /dev/null +++ b/src/lighteval/tasks/extended/mt_bench/judge_prompts.jsonl @@ -0,0 +1,8 @@ +{"name": "pair-v2", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user's instructions and answers the user's question better. 
Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your evaluation by comparing the two responses and provide a short explanation. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "[User Question]\n{question}\n\n[The Start of Assistant A's Answer]\n{answer_a}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{answer_b}\n[The End of Assistant B's Answer]", "description": "Prompt for general questions", "category": "general", "output_format": "[[A]]"} +{"name": "pair-v2-multi-turn", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user questions. You should choose the assistant that follows the user's instructions and answers the user's questions better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. You should focus on who provides a better answer to the second user question. Begin your evaluation by comparing the responses of the two assistants and provide a short explanation. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_a_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_a_2}\n\n<|The End of Assistant A's Conversation with User|>\n\n\n<|The Start of Assistant B's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant B:\n{answer_b_1}\n\n### User:\n{question_2}\n\n### Assistant B:\n{answer_b_2}\n\n<|The End of Assistant B's Conversation with User|>", "description": "Prompt for multi-turn general questions", "category": "general", "output_format": "[[A]]"} +{"name": "pair-math-v1", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. Your evaluation should consider correctness and helpfulness. You will be given a reference answer, assistant A's answer, and assistant B's answer. Your job is to evaluate which assistant's answer is better. Begin your evaluation by comparing both assistants' answers with the reference answer. Identify and correct any mistakes. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. 
After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "[User Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer_1}\n[The End of Reference Answer]\n\n[The Start of Assistant A's Answer]\n{answer_a}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{answer_b}\n[The End of Assistant B's Answer]", "description": "Prompt for math questions", "category": "math", "output_format": "[[A]]"} +{"name": "pair-math-v1-multi-turn", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user questions. Your evaluation should consider correctness and helpfulness. You will be given reference answers, the assistant A's answers, the assistant B's answers. Your job is to determine which assistant provides correct and helpful answers to the second user question. Begin your evaluation by comparing both assistants' answers with the reference answers. Identify and correct any mistakes. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "<|The Start of Reference Answer|>\n\n### User:\n{question_1}\n\n### Reference answer:\n{ref_answer_1}\n\n### User:\n{question_2}\n\n### Reference answer:\n{ref_answer_2}\n\n<|The End of Reference Answer|>\n\n\n<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_a_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_a_2}\n\n<|The End of Assistant A's Conversation with User|>\n\n\n<|The Start of Assistant B's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant B:\n{answer_b_1}\n\n### User:\n{question_2}\n\n### Assistant B:\n{answer_b_2}\n\n<|The End of Assistant B's Conversation with User|>", "description": "Prompt for multi-turn general questions", "category": "general", "output_format": "[[A]]"} +{"name": "single-v1", "type": "single", "system_prompt": "You are a helpful assistant.", "prompt_template": "[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. Begin your evaluation by providing a short explanation. Be as objective as possible. 
After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n[Question]\n{question}\n\n[The Start of Assistant's Answer]\n{answer}\n[The End of Assistant's Answer]", "description": "Prompt for general questions", "category": "general", "output_format": "[[rating]]"} +{"name": "single-math-v1", "type": "single", "system_prompt": "You are a helpful assistant.", "prompt_template": "[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n[Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer_1}\n[The End of Reference Answer]\n\n[The Start of Assistant's Answer]\n{answer}\n[The End of Assistant's Answer]", "description": "Prompt for general questions", "category": "math", "output_format": "[[rating]]"} +{"name": "single-v1-multi-turn", "type": "single", "system_prompt": "Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. You evaluation should focus on the assistant's answer to the second user question. Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n", "prompt_template": "<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_2}\n\n<|The End of Assistant A's Conversation with User|>", "description": "Prompt for general questions", "category": "general", "output_format": "[[rating]]"} +{"name": "single-math-v1-multi-turn", "type": "single", "system_prompt": "Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. You evaluation should focus on the assistant's answer to the second question. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. Be as objective as possible. 
After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n", "prompt_template": "<|The Start of Reference Answer|>\n\n### User:\n{question_1}\n\n### Reference answer:\n{ref_answer_1}\n\n### User:\n{question_2}\n\n### Reference answer:\n{ref_answer_2}\n\n<|The End of Reference Answer|>\n\n\n<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_2}\n\n<|The End of Assistant A's Conversation with User|>", "description": "Prompt for general questions", "category": "math", "output_format": "[[rating]]"} \ No newline at end of file diff --git a/src/lighteval/tasks/extended/mt_bench/judges.py b/src/lighteval/tasks/extended/mt_bench/judges.py new file mode 100644 index 00000000..a75d4eac --- /dev/null +++ b/src/lighteval/tasks/extended/mt_bench/judges.py @@ -0,0 +1,222 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Inspired by the FastChat Codebase: https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/README.md + + +import ast +import json +import re +import time +from abc import ABC +from typing import Optional + +from openai import OpenAI + +from lighteval.logging.hierarchical_logger import hlog_warn + + +# Abstract class for a judge +class Judge(ABC): + def evaluate_answer(answers, questions, references) -> tuple[str, list[dict[str, str]], str]: + pass + + +class JudgeOpenAI(Judge): + """ + A class representing a judge for evaluating answers using the OpenAI API. + + Args: + model (str): The name of the OpenAI model to use. + seed (int): The seed value for generating random responses. + temperature (float): The temperature value for controlling the randomness of the responses. + templates_path (str): The path to the JSON file containing the templates for prompts. + + Attributes: + client: An instance of the OpenAI client. + model (str): The name of the OpenAI model. + seed (int): The seed value, passed to the API when generating responses. + temperature (float): The temperature value, passed to the API when generating responses. + templates (dict): A dictionary containing the templates for prompts. + one_score_pattern (re.Pattern): A regular expression pattern for extracting scores from the response. 
+ one_score_pattern_backup (re.Pattern): A backup regular expression pattern for extracting scores. + API_MAX_RETRY (int): The maximum number of API retries. + API_RETRY_SLEEP (int): The sleep time between API retries. + max_tokens (int): The maximum number of tokens allowed in the response. + + Methods: + evaluate_answer: Evaluates an answer using the OpenAI API. + __get_prompts_multi_turn: Generates prompts for multi-turn conversations. + __get_prompts_single_turn: Generates prompts for single-turn conversations. + __process_judge_response: Processes the judge's response and extracts the score. + """ + + def __init__(self, model: str, seed: int, temperature: float, templates_path: str, openai_api_key: str): + self.client = OpenAI(api_key=openai_api_key) + self.model = model + self.seed = seed + self.temperature = temperature + + data = [] + with open(templates_path, "r") as f: + for line in f: + tmp = json.loads(line) + data.append(tmp) + + self.templates = {d["name"]: d for d in data} + + # Patterns for extracting scores from the response + # The first pattern is for the default case: [[score]], + # the second is for the backup case: [score] + self.one_score_pattern = re.compile(r"\[\[(\d+\.?\d*)\]\]") + self.one_score_pattern_backup = re.compile(r"\[(\d+\.?\d*)\]") + + self.API_MAX_RETRY = 16 + self.API_RETRY_SLEEP = 10 + self.max_tokens = 2048 + + def evaluate_answer( + self, questions: list[str], answers: list[str], references: list[str], single_turn: bool + ) -> tuple[int, list[dict[str, str]], str]: + """ + Evaluates an answer using the OpenAI API. + + Args: + questions (list[str]): A list of questions (can be a list because of multi-turn conversations) + answers (list[str]): A list of answers, one for each question. + references (list[str]): A list of reference answers, one for each question (sometimes not available) + single_turn (bool): Indicates whether the conversation is single-turn or multi-turn. + + Returns: + A tuple containing the score, prompts, and judgment. + + Raises: + Exception: If an error occurs during the API call. + """ + if single_turn: + prompts = self.__get_prompts_single_turn( + questions[0], answers[0], references[0] if len(references) > 0 else None + ) + else: + prompts = self.__get_prompts_multi_turn(questions, answers, references if len(references) > 1 else None) + + for _ in range(self.API_MAX_RETRY): + try: + response = self.client.chat.completions.create( + model=self.model, + seed=self.seed, + temperature=self.temperature, + messages=prompts, + max_tokens=self.max_tokens, + n=1, + ) + break + except Exception as e: + hlog_warn(f"{type(e), e}") + time.sleep(self.API_RETRY_SLEEP) + response = None + + if response is None: + raise Exception("Failed to get response from the API") + + judgment = response.choices[0].message.content + score = self.__process_judge_response(judgment) + + return score, prompts, judgment + + def __get_prompts_multi_turn( + self, questions: list[str], answers: list[str], references: Optional[list[str]] + ) -> list[dict[str, str]]: + """ + Generates prompts for multi-turn conversations. The prompts are generated based on the templates. + The prompt is different for the case where reference answers are available. + + Args: + questions (list[str]): A list of questions. + answers (list[str]): A list of answers. + references (Optional[list[str]]): A list of reference answers. + + Returns: + A list of prompts. 
+ """ + if references is None: + system_prompt = {"role": "system", "content": self.templates["single-v1-multi-turn"]["system_prompt"]} + user_prompt_str = self.templates["single-v1-multi-turn"]["prompt_template"].format( + question_1=questions[0], answer_1=answers[0], question_2=questions[1], answer_2=answers[1] + ) + else: + system_prompt = {"role": "system", "content": self.templates["single-math-v1-multi-turn"]["system_prompt"]} + user_prompt_str = self.templates["single-math-v1-multi-turn"]["prompt_template"].format( + question_1=questions[0], + answer_1=answers[0], + ref_answer_1=references[0], + question_2=questions[1], + answer_2=answers[1], + ref_answer_2=references[1], + ) + user_prompt = {"role": "user", "content": user_prompt_str} + return [system_prompt, user_prompt] + + def __get_prompts_single_turn(self, question: str, answer: str, reference: Optional[str]) -> list[dict[str, str]]: + """ + Generates prompts for single-turn conversations. The prompts are generated based on the templates. + The prompt is different for the case where a reference answer is available. + + Args: + question (str): The question. + answer (str): The answer. + reference (Optional[str]): The reference answer. + + Returns: + A list of prompts. + """ + if reference is None: + system_prompt = {"role": "system", "content": self.templates["single-v1"]["system_prompt"]} + user_prompt_str = self.templates["single-v1"]["prompt_template"].format(question=question, answer=answer) + else: + system_prompt = {"role": "system", "content": self.templates["single-math-v1"]["system_prompt"]} + user_prompt_str = self.templates["single-math-v1"]["prompt_template"].format( + question=question, answer=answer, ref_answer_1=reference + ) + user_prompt = {"role": "user", "content": user_prompt_str} + return [system_prompt, user_prompt] + + def __process_judge_response(self, judgment: str) -> int: + """ + Processes the judge's response and extracts the score. + Returns -1 if the score cannot be extracted. + + Args: + judgment (str): The judge's response. + + Returns: + The extracted score. + """ + match = re.search(self.one_score_pattern, judgment) + if not match: + match = re.search(self.one_score_pattern_backup, judgment) + if match: + rating = ast.literal_eval(match.groups()[0]) + else: + rating = -1 + + return rating diff --git a/src/lighteval/tasks/extended/mt_bench/main.py b/src/lighteval/tasks/extended/mt_bench/main.py new file mode 100644 index 00000000..e2c59511 --- /dev/null +++ b/src/lighteval/tasks/extended/mt_bench/main.py @@ -0,0 +1,135 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# ruff: noqa: F405, F403, F401, I001 + +import numpy as np +from aenum import extend_enum +from transformers import AutoModelForCausalLM, AutoTokenizer + +from lighteval.tasks.extended.mt_bench.judges import JudgeOpenAI +from lighteval.metrics import Metrics +from lighteval.metrics.utils import MetricCategory, MetricUseCase, SampleLevelMetric, SampleLevelMetricGrouping +from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc +from lighteval.tasks.tasks_prompt_formatting import LETTER_INDICES +from colorama import Fore, Style +import os + +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") + +if OPENAI_API_KEY is None: + # Using print here because hlog_warn is not yet available in this context + print( + Fore.YELLOW + + "No OpenAI API key found. If you are using the OpenAI judge, please set the OPENAI_API_KEY environment variable." + + Style.RESET_ALL + ) + +task = LightevalTaskConfig( + name="mt_bench", + prompt_function="prompt_fn", # must be defined in the file or imported from src/lighteval/tasks/tasks_prompt_formatting.py + suite=["extended"], + hf_repo="lighteval/mt-bench", + hf_subset="default", + hf_avail_splits=["train"], + evaluation_splits=["train"], + few_shots_split="", + few_shots_select="random", + metric=["mt_bench_metric"], + generation_size=1024, + stop_sequence=[], +) + + +def prompt_fn(line, task_name: str = None): + """Defines how to go from a dataset line to a doc object. + Follow examples in src/lighteval/tasks/tasks_prompt_formatting.py, or get more info + about what this function should do in the README. + """ + return Doc( + task_name=task_name, + query=f"{line['turns'][0]}", + choices=None, + instruction=None, + gold_index=[], + specific={ + "reference": line["reference"], + "category": line["category"], + "multi_turn_queries": line["turns"], + "id": line["question_id"], + }, + ) + + +def mt_bench_metric(predictions: list[str], formatted_doc: Doc, **kwargs) -> dict[str, float]: + """Defines how to go from a list of predictions to a score. + Follow examples in src/lighteval/metrics/metrics.py, or get more info + about what this function should do in the README. 
+    """
+
+    judge = JudgeOpenAI(
+        model="gpt-3.5-turbo",
+        seed=42,
+        temperature=0.0,
+        templates_path="src/lighteval/tasks/extended/mt_bench/judge_prompts.jsonl",
+        openai_api_key=OPENAI_API_KEY,
+    )
+
+    questions = formatted_doc.specific["multi_turn_queries"]
+    ref_answers = formatted_doc.specific["reference"]
+
+    score, messages, judgement = judge.evaluate_answer(questions, predictions, ref_answers, single_turn=True)
+    score_mt, messages_mt, judgement_mt = judge.evaluate_answer(questions, predictions, ref_answers, single_turn=False)
+
+    return {
+        "single_turn": score,
+        "multi_turn": score_mt,
+        "user_prompt": [messages, messages_mt],
+        "judgement": [judgement, judgement_mt],
+    }
+
+
+mt_bench_metric = SampleLevelMetricGrouping(
+    metric="mt_bench_metric",
+    higher_is_better=True,
+    category=MetricCategory.GENERATIVE_MULTI_TURN,
+    use_case=MetricUseCase.SUMMARIZATION,
+    sample_level_fn=mt_bench_metric,
+    corpus_level_fn={
+        "single_turn": np.mean,
+        "multi_turn": np.mean,
+    },
+)
+
+_TASKS = [task]
+
+TASKS_TABLE = [task.as_dict() for task in _TASKS]
+extend_enum(
+    Metrics,
+    "mt_bench_metric",
+    mt_bench_metric,
+)
+
+if __name__ == "__main__":
+    print([t["name"] for t in TASKS_TABLE])
+    print(len(TASKS_TABLE))
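As a side note, here is a condensed sketch of how the judge's free-form verdict is turned into a numeric score, mirroring `JudgeOpenAI.__process_judge_response` in `judges.py` above; the `extract_score` name exists only for this illustration.

```python
# Condensed sketch of the score extraction performed by
# JudgeOpenAI.__process_judge_response in judges.py above.
import ast
import re

one_score_pattern = re.compile(r"\[\[(\d+\.?\d*)\]\]")     # preferred form: "Rating: [[7]]"
one_score_pattern_backup = re.compile(r"\[(\d+\.?\d*)\]")  # fallback form: "Rating: [7]"


def extract_score(judgment):
    match = re.search(one_score_pattern, judgment) or re.search(one_score_pattern_backup, judgment)
    # -1 signals that no usable rating was found in the judge's verdict
    return ast.literal_eval(match.groups()[0]) if match else -1


print(extract_score("Concise and correct. Rating: [[8]]"))          # 8
print(extract_score("Decent attempt, minor slips. Rating: [6.5]"))  # 6.5
print(extract_score("No usable verdict here."))                     # -1
```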
diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py
index 41eca066..c0208cd4 100644
--- a/src/lighteval/tasks/lighteval_task.py
+++ b/src/lighteval/tasks/lighteval_task.py
@@ -34,6 +34,7 @@ from lighteval.metrics import (
     apply_generative_logprob_metric,
     apply_generative_metric,
+    apply_generative_multi_turn_metric,
     apply_multichoice_metric,
     apply_multichoice_metric_one_token,
     apply_perplexity_metric,
@@ -44,6 +45,7 @@ from lighteval.models.model_output import ModelReturn
 from lighteval.tasks.requests import (
     Doc,
+    GreedyUntilMultiTurnRequest,
     GreedyUntilRequest,
     GreedyUntilWithLogitsRequest,
     LoglikelihoodRequest,
@@ -410,6 +412,8 @@ def get_request_type(self) -> list[RequestType]:
             request_types.append(RequestType.LOGLIKELIHOOD_ROLLING)
         if self.has_metric_category[MetricCategory.GENERATIVE]:
             request_types.append(RequestType.GREEDY_UNTIL)
+        if self.has_metric_category[MetricCategory.GENERATIVE_MULTI_TURN]:
+            request_types.append(RequestType.GREEDY_UNTIL_MULTI_TURN)
         if self.has_metric_category[MetricCategory.GENERATIVE_LOGPROB]:
             request_types.append(RequestType.GREEDY_UNTIL_WITH_LOGITS)
         if self.has_metric_category[MetricCategory.MULTICHOICE]:
@@ -500,6 +504,17 @@ def construct_requests(
                 choices=formatted_doc.choices,
             )
         ]
+        if self.has_metric_category[MetricCategory.GENERATIVE_MULTI_TURN]:
+            requests[RequestType.GREEDY_UNTIL_MULTI_TURN] += [
+                GreedyUntilMultiTurnRequest(
+                    task_name=current_task_name,
+                    example_index=document_id_seed,
+                    request_index=0,
+                    context=context,
+                    stop_sequence=self.stop_sequence,
+                    generation_size=self.generation_size,
+                )
+            ]
 
         return requests
 
@@ -546,6 +561,11 @@ def process_results(self, formatted_doc: Doc, results: list[ModelReturn]) -> dic
                 results=results, formatted_doc=formatted_doc, metrics=self.metrics
             )
             outputs.update(cur_outputs)
+        if self.has_metric_category[MetricCategory.GENERATIVE_MULTI_TURN]:
+            results, cur_outputs = apply_generative_multi_turn_metric(
+                results=results, formatted_doc=formatted_doc, metrics=self.metrics
+            )
+            outputs.update(cur_outputs)
 
         return outputs
 
@@ -670,21 +690,31 @@ def create_requests_from_tasks(  # noqa: C901 # to fix!!
             cur_task_name = f"{task_name}|{num_fewshot}"
             doc = task_docs[doc_id]
-            ctx, num_effective_few_shots = task.fewshot_sampler.fewshot_context(
-                task=task,
-                doc=doc,
-                num_fewshot=num_fewshot,
-                seed=seed,
-                truncate_few_shots=truncate_few_shots,
-                max_model_length=lm.max_length,
-                sampler=rnd,
-                tokenizer=lm.tokenizer,
-                use_chat_template=use_chat_template,
-                system_prompt=system_prompt,
-            )
+            is_multi_turn = doc.specific is not None and len(doc.specific.get("multi_turn_queries", [])) > 0
+
+            if is_multi_turn:
+                ctx, num_effective_few_shots = task.fewshot_sampler.create_multi_turn_contexts(
+                    doc, use_chat_template, system_prompt, lm.tokenizer
+                )
+                doc.specific["multi_turn_queries_context"] = ctx
+            else:
+                ctx, num_effective_few_shots = task.fewshot_sampler.fewshot_context(
+                    task=task,
+                    doc=doc,
+                    num_fewshot=num_fewshot,
+                    seed=seed,
+                    truncate_few_shots=truncate_few_shots,
+                    max_model_length=lm.max_length,
+                    sampler=rnd,
+                    tokenizer=lm.tokenizer,
+                    use_chat_template=use_chat_template,
+                    system_prompt=system_prompt,
+                )
+
             doc.num_effective_few_shots = num_effective_few_shots
             doc.num_asked_few_shots = num_fewshot
             doc.ctx = ctx
 
+            # Constructing the requests
             docs[TaskExampleId(cur_task_name, doc_id_seed)] = doc
             reqs = task.construct_requests(doc, ctx, doc_id_seed, cur_task_name)
diff --git a/src/lighteval/tasks/requests.py b/src/lighteval/tasks/requests.py
index 86038dc7..c4c86335 100644
--- a/src/lighteval/tasks/requests.py
+++ b/src/lighteval/tasks/requests.py
@@ -33,6 +33,7 @@ class RequestType(Enum):
     LOGLIKELIHOOD_SINGLE_TOKEN = auto()
     LOGLIKELIHOOD_ROLLING = auto()
     GREEDY_UNTIL = auto()
+    GREEDY_UNTIL_MULTI_TURN = auto()
     GREEDY_UNTIL_WITH_LOGITS = auto()
 
 
@@ -120,6 +121,22 @@ class GreedyUntilRequest(Request):
     tokenized_context: list[int] = None
 
 
+@dataclass
+class GreedyUntilMultiTurnRequest(Request):
+    """
+    Represents a request for generating text over multiple turns using a greedy-until decoding strategy.
+
+    Attributes:
+        stop_sequence (str): The sequence of tokens that indicates when to stop generating text.
+        generation_size (int): The maximum number of tokens to generate.
+        request_type (RequestType): The type of the request, set to RequestType.GREEDY_UNTIL_MULTI_TURN.
+    """
+
+    stop_sequence: str
+    generation_size: int
+    request_type = RequestType.GREEDY_UNTIL_MULTI_TURN
+
+
 @dataclass
 class GreedyUntilWithLogitsRequest(Request):
     """