-
Notifications
You must be signed in to change notification settings - Fork 145
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Fix `ChatGeneration.format_input` exception * Bump `datasets` to 2.16.0 or higher To be able to efficiently use the cache via `load_dataset` whenever there's no connection * Add `benchmarks/arena_hard.py` (WIP) * Update `_get_hf_dataset_info` error message * Add `ArenaHard` and `ArenaHardResults` docstrings * Catch `ImportError` on `ArenaHardResults.load` * Add `ArenaHard` and `ArenaHardEval` imports * Add `arena-hard` extras * Install `arena-hard` extra in `test.yml` * Update `arena-hard` extra dependencies * Fix circular import in `arena_hard.py` * Apply suggestions from code review * Add missing examples in docstrings & fix type-hints * Add some future TODOs * Update `install_dependencies.sh` to install `arena-hard` extra * Add `ArenaHard` and `ArenaHardResults` unit tests
- Loading branch information
1 parent
0e8c752
commit d32d664
Showing
10 changed files
with
532 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2023-present, Argilla, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,320 @@ | ||
# Copyright 2023-present, Argilla, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import re | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from typing_extensions import override | ||
|
||
from distilabel.steps import GlobalStep, StepInput | ||
from distilabel.steps.tasks.base import Task | ||
from distilabel.steps.tasks.typing import ChatType | ||
from distilabel.steps.typing import StepOutput | ||
|
||
|
||
class ArenaHard(Task): | ||
"""This `Task` is based on the "From Live Data to High-Quality Benchmarks: The | ||
Arena-Hard Pipeline" paper that presents Arena Hard, which is a benchmark for | ||
instruction-tuned LLMs that contains 500 challenging user queries. GPT-4 is used | ||
as the judge to compare the model responses against a baseline model, which defaults | ||
to `gpt-4-0314`. | ||
Note: | ||
Arena-Hard-Auto has the highest correlation and separability to Chatbot Arena | ||
among popular open-ended LLM benchmarks. | ||
Input columns: | ||
- instruction (`str`): The instruction to evaluate the responses. | ||
- generations (`List[str]`): The responses generated by two, and only two, LLMs. | ||
Output columns: | ||
- evaluation (`str`): The evaluation of the responses generated by the LLMs. | ||
- score (`str`): The score extracted from the evaluation. | ||
- model_name (`str`): The model name used to generate the evaluation. | ||
Categories: | ||
- benchmark | ||
References: | ||
- [From Live Data to High-Quality Benchmarks: The Arena-Hard Pipeline](https://lmsys.org/blog/2024-04-19-arena-hard/) | ||
- [`arena-hard-auto`](https://github.com/lm-sys/arena-hard-auto/tree/main) | ||
Examples: | ||
Evaluate two assistant responses for a given instruction using Arean Hard prompts: | ||
```python | ||
from distilabel.pipeline import Pipeline | ||
from distilabel.steps import CombineColumns, LoadDataFromDicts | ||
from distilabel.steps.tasks import ArenaHard, TextGeneration | ||
with Pipeline() as pipeline: | ||
load_data = LoadDataFromDicts( | ||
data=[{"instruction": "What is the capital of France?"}], | ||
) | ||
text_generation_a = TextGeneration( | ||
llm=..., # LLM instance | ||
output_mappings={"model_name": "generation_model"}, | ||
) | ||
text_generation_b = TextGeneration( | ||
llm=..., # LLM instance | ||
output_mappings={"model_name": "generation_model"}, | ||
) | ||
combine = CombineColumns( | ||
columns=["generation", "generation_model"], | ||
output_columns=["generations", "generation_models"], | ||
) | ||
arena_hard = ArenaHard( | ||
llm=..., # LLM instance | ||
) | ||
load_data >> [text_generation_a, text_generation_b] >> combine >> arena_hard | ||
``` | ||
""" | ||
|
||
@property | ||
def inputs(self) -> List[str]: | ||
"""The inputs required by this task are the `instruction` and the `generations`, | ||
which are the responses generated by two, and only two, LLMs.""" | ||
return ["instruction", "generations"] | ||
|
||
def format_input(self, input: Dict[str, Any]) -> ChatType: | ||
"""This method formats the input data as a `ChatType` using the prompt defined | ||
by the Arena Hard benchmark, which consists on a `system_prompt` plus a template | ||
for the user first message that contains the `instruction` and both `generations`. | ||
""" | ||
return [ | ||
{ | ||
"role": "system", | ||
"content": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user prompt displayed below. You will be given assistant A's answer and assistant B's answer. Your job is to evaluate which assistant's answer is better.\n\nBegin your evaluation by generating your own answer to the prompt. You must provide your answers before judging any answers.\n\nWhen evaluating the assistants' answers, compare both assistants' answers with your answer. You must identify and correct any mistakes or inaccurate information.\n\nThen consider if the assistant's answers are helpful, relevant, and concise. Helpful means the answer correctly responds to the prompt or follows the instructions. Note when user prompt has any ambiguity or more than one interpretation, it is more helpful and appropriate to ask for clarifications or more information from the user than providing an answer based on assumptions. Relevant means all parts of the response closely connect or are appropriate to what is being asked. Concise means the response is clear and not verbose or excessive.\n\nThen consider the creativity and novelty of the assistant's answers when needed. Finally, identify any missing important information in the assistants' answers that would be beneficial to include when responding to the user prompt.\n\nAfter providing your explanation, you must output only one of the following choices as your final verdict with a label:\n\n1. Assistant A is significantly better: [[A>>B]]\n2. Assistant A is slightly better: [[A>B]]\n3. Tie, relatively the same: [[A=B]]\n4. Assistant B is slightly better: [[B>A]]\n5. Assistant B is significantly better: [[B>>A]]\n\nExample output: \"My final verdict is tie: [[A=B]]\".", | ||
}, | ||
{ | ||
"role": "user", | ||
"content": f"<|User Prompt|>\n{input['instruction']}\n\n<|The Start of Assistant A's Answer|>\n{input['generations'][0]}\n<|The End of Assistant A's Answer|>\n\n<|The Start of Assistant B's Answer|>\n{input['generations'][1]}\n<|The End of Assistant B's Answer|>", | ||
}, | ||
] | ||
|
||
@property | ||
def outputs(self) -> List[str]: | ||
"""The outputs generated by this task are the `evaluation`, the `score` and | ||
the `model_name` (which is automatically injected within the `process` method | ||
of the parent task).""" | ||
return ["evaluation", "score", "model_name"] | ||
|
||
def format_output( | ||
self, | ||
output: Union[str, None], | ||
input: Union[Dict[str, Any], None] = None, | ||
) -> Dict[str, Any]: | ||
"""This method formats the output generated by the LLM as a Python dictionary | ||
containing the `evaluation` which is the raw output generated by the LLM (consisting | ||
of the judge LLM alternate generation for the given instruction, plus an explanation | ||
on the evaluation of the given responses; plus the `score` extracted from the output. | ||
Args: | ||
output: the raw output of the LLM. | ||
input: the input to the task. Is provided in case it needs to be used to enrich | ||
the output if needed. | ||
Returns: | ||
A dict with the keys `evaluation` with the raw output which contains the LLM | ||
evaluation and the extracted `score` if possible. | ||
""" | ||
if output is None: | ||
return {"evaluation": None, "score": None} | ||
pattern = re.compile(r"\[\[([AB<>=]+)\]\]") | ||
match = pattern.search(output) | ||
if match is None: | ||
return {"evaluation": output, "score": None} | ||
return {"evaluation": output, "score": match.group(1)} | ||
|
||
|
||
class ArenaHardResults(GlobalStep): | ||
"""This `Step` is based on the "From Live Data to High-Quality Benchmarks: The | ||
Arena-Hard Pipeline" paper that presents Arena Hard, which is a benchmark for | ||
instruction-tuned LLMs that contains 500 challenging user queries. This step is | ||
a `GlobalStep` that should run right after the `ArenaHard` task to calculate the | ||
ELO scores for the evaluated models. | ||
Note: | ||
Arena-Hard-Auto has the highest correlation and separability to Chatbot Arena | ||
among popular open-ended LLM benchmarks. | ||
References: | ||
- [From Live Data to High-Quality Benchmarks: The Arena-Hard Pipeline](https://lmsys.org/blog/2024-04-19-arena-hard/) | ||
- [`arena-hard-auto`](https://github.com/lm-sys/arena-hard-auto/tree/main) | ||
Examples: | ||
Rate the ELO scores for two assistant responses for a given an evaluation / comparison between both using Arean Hard prompts: | ||
```python | ||
from distilabel.pipeline import Pipeline | ||
from distilabel.steps import CombineColumns, LoadDataFromDicts | ||
from distilabel.steps.tasks import ArenaHard, TextGeneration | ||
with Pipeline() as pipeline: | ||
load_data = LoadDataFromDicts( | ||
data=[{"instruction": "What is the capital of France?"}], | ||
) | ||
text_generation_a = TextGeneration( | ||
llm=..., # LLM instance | ||
output_mappings={"model_name": "generation_model"}, | ||
) | ||
text_generation_b = TextGeneration( | ||
llm=..., # LLM instance | ||
output_mappings={"model_name": "generation_model"}, | ||
) | ||
combine = CombineColumns( | ||
columns=["generation", "generation_model"], | ||
output_columns=["generations", "generation_models"], | ||
) | ||
arena_hard = ArenaHard( | ||
llm=..., # LLM instance | ||
) | ||
arena_hard_results = ArenaHardResults( | ||
custom_model_column="generation_models", | ||
custom_weights={"A>B": 1, "A>>B": 3, "B>A": 1, "B>>A": 3}, | ||
) | ||
load_data >> [text_generation_a, text_generation_b] >> combine >> arena_hard >> arena_hard_results | ||
``` | ||
""" | ||
|
||
custom_model_column: Optional[str] = None | ||
custom_weights: Dict[str, int] = {"A>B": 1, "A>>B": 3, "B>A": 1, "B>>A": 3} | ||
|
||
def load(self) -> None: | ||
"""Ensures that the required dependencies are installed.""" | ||
super().load() | ||
|
||
try: | ||
import numpy as np # noqa: F401 | ||
import pandas as pd # noqa: F401 | ||
from sklearn.linear_model import LogisticRegression # noqa: F401 | ||
except ImportError as e: | ||
raise ImportError( | ||
"In order to run `ArenaHardResults`, the `arena-hard` extra dependencies" | ||
" must be installed i.e. `numpy`, `pandas`, and `scikit-learn`.\n" | ||
"Please install the dependencies by running `pip install distilabel[arena-hard]`." | ||
) from e | ||
|
||
# TODO: the `evaluation` is not really required as an input, so it could be removed, since | ||
# only `score` is used / required | ||
@property | ||
def inputs(self) -> List[str]: | ||
"""The inputs required by this step are the `evaluation` and the `score` generated | ||
by the `ArenaHard` task. Since this step does use the identifiers `model_a` and `model_b`, | ||
optionally one can set `custom_model_column` to use the model names if existing within | ||
the input data, ideally this value should be `model_name` if connected from the `ArenaHard` | ||
step.""" | ||
columns = ["evaluation", "score"] | ||
if self.custom_model_column: | ||
columns.append(self.custom_model_column) | ||
return columns | ||
|
||
@override | ||
def process(self, inputs: StepInput) -> StepOutput: # type: ignore | ||
"""This method processes the inputs generated by the `ArenaHard` task to calculate the | ||
win rates for each of the models to evaluate. Since this step inherits from the `GlobalStep`, | ||
it will wait for all the input batches to be processed, and then the output will be yielded in | ||
case there's a follow up step, since this step won't modify the received inputs. | ||
Args: | ||
inputs: A list of Python dictionaries with the inputs of the task. | ||
Yields: | ||
A list of Python dictionaries with the outputs of the task. | ||
References: | ||
- https://github.com/lm-sys/arena-hard-auto/blob/main/show_result.py | ||
""" | ||
import numpy as np | ||
import pandas as pd | ||
from sklearn.linear_model import LogisticRegression | ||
|
||
models = ["A", "B"] | ||
if self.custom_model_column: | ||
models = inputs[0][self.custom_model_column] | ||
|
||
# TODO: the battles are only calculated for the first game, even though the official | ||
# implementation also covers the possibility of a second game (not within the released | ||
# dataset yet) | ||
battles = pd.DataFrame() | ||
for input in inputs: | ||
output = { | ||
# TODO: "question_id": input["question_id"], | ||
"model_a": models[0], | ||
"model_b": models[1], | ||
} | ||
if input["score"] in ["A>B", "A>>B"]: | ||
output["winner"] = models[0] | ||
rows = [output] * self.custom_weights[input["score"]] | ||
elif input["score"] in ["B>A", "B>>A"]: | ||
output["winner"] = models[1] | ||
rows = [output] * self.custom_weights[input["score"]] | ||
elif input["score"] == "A=B": | ||
output["winner"] = "tie" | ||
rows = [output] | ||
else: | ||
continue | ||
|
||
battles = pd.concat([battles, pd.DataFrame(rows)]) | ||
|
||
models = pd.concat([battles["model_a"], battles["model_b"]]).unique() | ||
models = pd.Series(np.arange(len(models)), index=models) | ||
|
||
battles = pd.concat([battles, battles], ignore_index=True) | ||
p = len(models.index) | ||
n = battles.shape[0] | ||
|
||
X = np.zeros([n, p]) | ||
X[np.arange(n), models[battles["model_a"]]] = +np.log(10) | ||
X[np.arange(n), models[battles["model_b"]]] = -np.log(10) | ||
|
||
Y = np.zeros(n) | ||
Y[battles["winner"] == "model_a"] = 1.0 | ||
|
||
tie_idx = battles["winner"] == "tie" | ||
tie_idx[len(tie_idx) // 2 :] = False | ||
Y[tie_idx] = 1.0 | ||
|
||
lr = LogisticRegression(fit_intercept=False, penalty=None, tol=1e-8) # type: ignore | ||
lr.fit(X, Y) | ||
|
||
# The ELO scores are calculated assuming that the reference is `gpt-4-0314` | ||
# with an starting ELO of 1000, so that the evaluated models are compared with | ||
# `gtp-4-0314` only if it's available within the models | ||
elo_scores = 400 * lr.coef_[0] + 1000 | ||
# TODO: we could parametrize the reference / anchor model, but left as is to be faithful to the | ||
# original implementation | ||
if "gpt-4-0314" in models.index: | ||
elo_scores += 1000 - elo_scores[models["gpt-4-0314"]] | ||
|
||
output = pd.Series(elo_scores, index=models.index).sort_values(ascending=False) | ||
self._logger.info(f"Arena Hard ELO: {output}") | ||
|
||
# Here only so that if follow up steps are connected the inputs are preserved, | ||
# since this step doesn't modify nor generate new inputs | ||
yield inputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.