From eca6926be92e3349aef4e4214d43817b24108c8d Mon Sep 17 00:00:00 2001 From: Baber Abbasi <92168766+baberabb@users.noreply.github.com> Date: Thu, 4 Jan 2024 17:08:52 +0500 Subject: [PATCH 1/6] vllm: handle max_length better and substitute Collator (#1241) * copies max_length from huggingface * handle max_length properly * get tokens from inputs * substitute Collator for Reorderer * `batch=auto` if using data_parallel * nit * cleanup * update code comments * `ray.shutdown()` after calling method if data_parallel_size > 1 --------- Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com> --- lm_eval/models/vllm_causallms.py | 223 ++++++++++++++++--------------- 1 file changed, 118 insertions(+), 105 deletions(-) diff --git a/lm_eval/models/vllm_causallms.py b/lm_eval/models/vllm_causallms.py index 177124f112..6912428ed1 100644 --- a/lm_eval/models/vllm_causallms.py +++ b/lm_eval/models/vllm_causallms.py @@ -1,30 +1,36 @@ import copy -from collections import defaultdict from importlib.util import find_spec from typing import List, Literal, Optional, Tuple, Union from tqdm import tqdm -from lm_eval import utils from lm_eval.api.instance import Instance from lm_eval.api.model import LM from lm_eval.api.registry import register_model +from lm_eval.utils import ( + Collator, + divide, + eval_logger, + get_rolling_token_windows, + make_disjoint_window, +) try: + import ray from ray.util.multiprocessing import Pool from vllm import LLM, SamplingParams from vllm.transformers_utils.tokenizer import get_tokenizer except ModuleNotFoundError: pass -eval_logger = utils.eval_logger +eval_logger = eval_logger # adapted from https://github.com/vllm-project/vllm/issues/367#issuecomment-1788341727 -def run_inference_one_model(model_args: dict, sampling_params, requests: List[int]): - # gpu_id = [x for x in gpu_id] - # os.environ["CUDA_VISIBLE_DEVICES"]= str(gpu_id) +def run_inference_one_model( + model_args: dict, sampling_params, requests: List[List[int]] +): llm = LLM(**model_args) return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params) @@ -43,7 +49,7 @@ def __init__( tokenizer_mode: Literal["auto", "slow"] = "auto", tokenizer_revision: Optional[str] = None, tensor_parallel_size: int = 1, - quantization: Optional[Literal["awq"]] = None, + quantization: Optional[str] = None, max_gen_toks: int = 256, swap_space: int = 4, batch_size: Union[str, int] = 1, @@ -86,10 +92,23 @@ def __init__( "quantization": quantization, "seed": int(seed), } + self.batch_size = ( + "auto" + if isinstance(batch_size, str) and "auto" in batch_size + else batch_size + ) if self.data_parallel_size <= 1: self.model = LLM(**self.model_args) else: self.model_args["worker_use_ray"] = True + self.batch_size = "auto" + eval_logger.info("Manual batching is not compatible with data parallelism.") + + from transformers import AutoConfig + + self._config = AutoConfig.from_pretrained( + pretrained, trust_remote_code=trust_remote_code, revision=revision + ) self.tokenizer = get_tokenizer( tokenizer if tokenizer else pretrained, tokenizer_mode=tokenizer_mode, @@ -97,7 +116,6 @@ def __init__( tokenizer_revision=tokenizer_revision, ) - self.batch_size = "auto" if batch_size.startswith("auto:") else batch_size self._max_gen_toks = max_gen_toks @property @@ -109,9 +127,18 @@ def eot_token_id(self): def max_length(self): if self._max_length: # if max length manually set, return it return self._max_length - if hasattr(self.tokenizer, "model_max_length"): - return self.tokenizer.model_max_length - return self._DEFAULT_MAX_LENGTH + if self.data_parallel_size <= 1: + return self.model.llm_engine.model_config.max_model_len + else: + seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx") + for attr in seqlen_config_attrs: + if hasattr(self._config, attr): + return getattr(self._config, attr) + if hasattr(self.tokenizer, "model_max_length"): + if self.tokenizer.model_max_length == 1000000000000000019884624838656: + return self._DEFAULT_MAX_LENGTH + return self.tokenizer.model_max_length + return self._DEFAULT_MAX_LENGTH @property def max_gen_toks(self): @@ -157,13 +184,13 @@ def _model_generate( temperature=0, prompt_logprobs=2, max_tokens=1 ) if self.data_parallel_size > 1: - requests = [ - list(x) for x in utils.divide(requests, self.data_parallel_size) - ] + requests = [list(x) for x in divide(requests, self.data_parallel_size)] inputs = [(self.model_args, sampling_params, req) for req in requests] with Pool(self.data_parallel_size) as pool: results = pool.starmap(run_inference_one_model, inputs) + # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required. + ray.shutdown() # flatten results return [item for sublist in results for item in sublist] @@ -172,7 +199,6 @@ def _model_generate( sampling_params=sampling_params, use_tqdm=True if self.batch_size == "auto" else False, ) - return outputs def _encode_pair( @@ -212,8 +238,8 @@ def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: for (string,) in tqdm([req.args for req in requests]): rolling_token_windows = list( map( - utils.make_disjoint_window, - utils.get_rolling_token_windows( + make_disjoint_window, + get_rolling_token_windows( token_list=self.tok_encode(string), prefix_token=self.eot_token_id, max_seq_len=self.max_length - 1, @@ -236,8 +262,7 @@ def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: return loglikelihoods def generate_until(self, requests: List[Instance]) -> List[str]: - res = defaultdict(list) - re_ords = {} + res = [] # batch tokenize contexts context, all_gen_kwargs = zip(*(req.args for req in requests)) @@ -253,84 +278,73 @@ def _collate_gen(_requests): # padded context length. this is useful to simplify the batching logic and more importantly to make # automatic adaptive batches much much easier to implement # - any OOMs will happen right away rather than near the end - return -len(_requests[0][1]), tuple(_requests[0][1]) + return -len(_requests[0][1]), _requests[0][0] # we group requests by their generation_kwargs, # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling # in the same batch. - grouper = utils.Grouper(requests, lambda x: str(x[1])) - for key, reqs in grouper.get_grouped().items(): - # within each set of reqs for given kwargs, we reorder by token length, descending. - re_ords[key] = utils.Reorderer(requests, _collate_gen) + re_ords = Collator(requests, _collate_gen, grouping=True) + chunks = re_ords.get_batched( + n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None + ) pbar = tqdm(total=len(requests), disable=(self.rank != 0)) # for each different set of kwargs, we execute all requests, by batch. - for key, re_ord in re_ords.items(): - chunks = utils.chunks( - re_ord.get_reordered(), - n=int(self.batch_size) if self.batch_size != "auto" else 0, - fn=None, - ) - for chunk in chunks: - context_and_encoding, all_gen_kwargs = zip(*chunk) - context, context_encoding = zip(*context_and_encoding) - # we assume all gen kwargs in the batch are the same - # this is safe to assume because the `grouper` object ensures it. - gen_kwargs = all_gen_kwargs[0] - # unpack our keyword arguments. - until = None - if isinstance(gen_kwargs, dict): - kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 - if "until" in kwargs.keys(): - until = kwargs.pop("until") - if isinstance(until, str): - until = [until] - elif not isinstance(until, list): - raise ValueError( - f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}" - ) - else: - raise ValueError( - f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}" - ) - if not until: - until = [self.tokenizer.decode(self.eot_token_id)] - if "max_gen_toks" in kwargs.keys(): - max_gen_toks = kwargs.pop("max_gen_toks") - else: - max_gen_toks = self.max_gen_toks - - # set the max length in tokens of inputs ("context_enc") - # max len for inputs = max length, minus room to generate the max new tokens - max_ctx_len = self.max_length - max_gen_toks - context_encoding = [x[-max_ctx_len:] for x in context_encoding] - - # TODO: max_length in kwargs - - # perform batched generation - cont = self._model_generate( - requests=context_encoding, - generate=True, - max_tokens=max_gen_toks, - stop=until, - **kwargs, + for chunk in chunks: + context_and_encoding, all_gen_kwargs = zip(*chunk) + context, context_encoding = zip(*context_and_encoding) + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + # unpack our keyword arguments. + until = None + if isinstance(gen_kwargs, dict): + kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 + if "until" in kwargs.keys(): + until = kwargs.pop("until") + if isinstance(until, str): + until = [until] + elif not isinstance(until, list): + raise ValueError( + f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}" + ) + else: + raise ValueError( + f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}" ) + if not until: + until = [self.tokenizer.decode(self.eot_token_id)] + if "max_gen_toks" in kwargs.keys(): + max_gen_toks = kwargs.pop("max_gen_toks") + else: + max_gen_toks = self.max_gen_toks + + # set the max length in tokens of inputs ("context_enc") + # max len for inputs = max length, minus room to generate the max new tokens + max_ctx_len = self.max_length - max_gen_toks + context_encoding = [x[-max_ctx_len:] for x in context_encoding] + + # perform batched generation + cont = self._model_generate( + requests=context_encoding, + generate=True, + max_tokens=max_gen_toks, + stop=until, + **kwargs, + ) - # cache generations - for output, context in zip(cont, context): - generated_text = output.outputs[0].text - res[key].append(generated_text) - self.cache_hook.add_partial( - "generate_until", (context, gen_kwargs), generated_text - ) - pbar.update(1) - - # reorder this group of results back to original unsorted form - res[key] = re_ord.get_original(res[key]) + # cache generations + for output, context in zip(cont, context): + generated_text = output.outputs[0].text + res.append(generated_text) + self.cache_hook.add_partial( + "generate_until", (context, gen_kwargs), generated_text + ) + pbar.update(1) pbar.close() - - return grouper.get_original(res) + # reorder all group of results back to original unsorted form + return re_ords.get_original(res) def _loglikelihood_tokens( self, @@ -343,16 +357,15 @@ def _collate(x): toks = x[1] + x[2] return -len(toks), tuple(toks) - re_ord = utils.Reorderer(requests, _collate) - - chunks = utils.chunks( - re_ord.get_reordered(), - n=int(self.batch_size) if self.batch_size != "auto" else 0, - fn=None, + # Reorder requests by length and batch + re_ord = Collator(requests, sort_fn=_collate) + chunks = re_ord.get_batched( + n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None ) + pbar = tqdm(total=len(requests), disable=disable_tqdm) for chunk in chunks: - inps = [] + inputs = [] ctxlens = [] for cache_key, context_enc, continuation_enc in chunk: inp = (context_enc + continuation_enc)[-(self.max_length) :] @@ -360,18 +373,18 @@ def _collate(x): 0, len(context_enc) + len(continuation_enc) - (self.max_length) ) - inps.append(inp) + inputs.append(inp) ctxlens.append(ctxlen) - outputs = self._model_generate(requests=inps, generate=False) + outputs = self._model_generate(requests=inputs, generate=False) - for output, ctxlen, (cache_key, context_enc, continuation_enc) in zip( - outputs, ctxlens, chunk + for output, ctxlen, (cache_key, _, _), inp in zip( + outputs, ctxlens, chunk, inputs ): answer = self._parse_logprobs( - (context_enc + continuation_enc), - output, - ctxlen, + tokens=inp, + outputs=output, + ctxlen=ctxlen, ) res.append(answer) @@ -379,7 +392,7 @@ def _collate(x): # partial caching if cache_key is not None: self.cache_hook.add_partial("loglikelihood", cache_key, answer) - pbar.update(1) + pbar.update(1) pbar.close() return re_ord.get_original(res) @@ -388,9 +401,9 @@ def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]: """Process logprobs and tokens. :param tokens: list - Tokens from context+continuations + Input tokens (potentially left-truncated) :param outputs: RequestOutput - Contains prompt + Contains prompt_logprobs :param ctxlen: int Length of context (so we can slice them away and only keep the predictions) :return: @@ -400,11 +413,11 @@ def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]: Whether argmax matches given continuation exactly """ - # prompt_logprobs = [None, {}*len(context-1)] + # The first entry of prompt_logprobs is None because the model has no previous tokens to condition on. continuation_logprobs_dicts = outputs.prompt_logprobs # Calculate continuation_logprobs - # assume ctxlen always > 1 + # assume ctxlen always >= 1 continuation_logprobs = sum( logprob_dict.get(token) for token, logprob_dict in zip( From e7c03d0c73587e6617eb98a7e5234df09d48b3c3 Mon Sep 17 00:00:00 2001 From: Lintang Sutawika Date: Thu, 4 Jan 2024 21:11:22 +0700 Subject: [PATCH 2/6] Remove self.dataset_path post_init process (#1243) * Remove self.dataset_path post_init process * Update task.py * Update task.py --- lm_eval/api/task.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/lm_eval/api/task.py b/lm_eval/api/task.py index 217349426c..a92cf151fc 100644 --- a/lm_eval/api/task.py +++ b/lm_eval/api/task.py @@ -1,7 +1,6 @@ import abc import ast import logging -import os import random import re from collections.abc import Callable @@ -87,12 +86,6 @@ class TaskConfig(dict): ] = None # by default, not used in the code. allows for users to pass arbitrary info to tasks def __post_init__(self) -> None: - if self.dataset_path and os.path.exists(os.path.dirname(self.dataset_path)): - import inspect - from importlib import import_module - - self.dataset_path = inspect.getfile(import_module(self.dataset_path)) - if self.generation_kwargs is not None: if self.output_type != "generate_until": eval_logger.warning( From 28bb45fbfc6872372a220759c36a0b36b05831de Mon Sep 17 00:00:00 2001 From: JorgeDeCorte <91887115+JorgeDeCorte@users.noreply.github.com> Date: Fri, 5 Jan 2024 01:37:09 +0100 Subject: [PATCH 3/6] Add multilingual HellaSwag task (#1228) * add hellaswag_nl * add other languages and update readme to hellaswag * refactor as new task * update readme * add endline to yaml files and readme.md * add group, change folder location and update yaml file * rename default hellaswag yaml file * fix whitespace error in some labels * downgrade log level of whitespace checking --------- Co-authored-by: JorgeDeCorte Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com> --- lm_eval/api/task.py | 6 +-- .../okapi/hellaswag_multilingual/README.md | 48 +++++++++++++++++++ .../hellaswag_multilingual/_hellaswag_yaml | 21 ++++++++ .../hellaswag_multilingual/hellaswag_ar.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_bn.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_ca.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_da.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_de.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_es.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_eu.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_fr.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_gu.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_hi.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_hr.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_hu.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_hy.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_id.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_it.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_kn.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_ml.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_mr.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_ne.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_nl.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_pt.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_ro.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_ru.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_sk.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_sr.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_sv.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_ta.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_te.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_uk.yaml | 6 +++ .../hellaswag_multilingual/hellaswag_vi.yaml | 6 +++ .../okapi/hellaswag_multilingual/utils.py | 24 ++++++++++ 34 files changed, 276 insertions(+), 3 deletions(-) create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/README.md create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/_hellaswag_yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ar.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_bn.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ca.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_da.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_de.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_es.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_eu.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_fr.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_gu.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_hi.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_hr.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_hu.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_hy.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_id.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_it.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_kn.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ml.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_mr.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ne.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_nl.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_pt.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ro.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ru.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_sk.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_sr.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_sv.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ta.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_te.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_uk.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_vi.yaml create mode 100644 lm_eval/tasks/okapi/hellaswag_multilingual/utils.py diff --git a/lm_eval/api/task.py b/lm_eval/api/task.py index a92cf151fc..f148b3b10a 100644 --- a/lm_eval/api/task.py +++ b/lm_eval/api/task.py @@ -698,11 +698,11 @@ def __init__( ) if delimiter_has_whitespace and choice_has_whitespace: - eval_logger.warning( - f'Both target_delimiter and target choice: "{choice}" have whitespace' + eval_logger.debug( + f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace' ) elif (not delimiter_has_whitespace) and (not choice_has_whitespace): - eval_logger.warning( + eval_logger.debug( f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace' ) diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/README.md b/lm_eval/tasks/okapi/hellaswag_multilingual/README.md new file mode 100644 index 0000000000..5af16562e0 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/README.md @@ -0,0 +1,48 @@ +# Multilingual HellaSwag + +### Paper + +Title: `Okapi: Instruction-tuned Large Language Models in Multiple Languages with Reinforcement Learning from Human Feedback` + +Abstract: https://arxiv.org/abs/2307.16039 + +A key technology for the development of large language models (LLMs) involves instruction tuning that helps align the models' responses with human expectations to realize impressive learning abilities. Two major approaches for instruction tuning characterize supervised fine-tuning (SFT) and reinforcement learning from human feedback (RLHF), which are currently applied to produce the best commercial LLMs (e.g., ChatGPT). To improve the accessibility of LLMs for research and development efforts, various instruction-tuned open-source LLMs have also been introduced recently, e.g., Alpaca, Vicuna, to name a few. However, existing open-source LLMs have only been instruction-tuned for English and a few popular languages, thus hindering their impacts and accessibility to many other languages in the world. Among a few very recent work to explore instruction tuning for LLMs in multiple languages, SFT has been used as the only approach to instruction-tune LLMs for multiple languages. This has left a significant gap for fine-tuned LLMs based on RLHF in diverse languages and raised important questions on how RLHF can boost the performance of multilingual instruction tuning. To overcome this issue, we present Okapi, the first system with instruction-tuned LLMs based on RLHF for multiple languages. Okapi introduces instruction and response-ranked data in 26 diverse languages to facilitate the experiments and development of future multilingual LLM research. We also present benchmark datasets to enable the evaluation of generative LLMs in multiple languages. Our experiments demonstrate the advantages of RLHF for multilingual instruction over SFT for different base models and datasets. Our framework and resources are released at this https URL. + +Homepage: `https://github.com/nlp-uoregon/Okapi` + + +### Citation + +``` +@article{dac2023okapi, + title={Okapi: Instruction-tuned Large Language Models in Multiple Languages with Reinforcement Learning from Human Feedback}, + author={Dac Lai, Viet and Van Nguyen, Chien and Ngo, Nghia Trung and Nguyen, Thuat and Dernoncourt, Franck and Rossi, Ryan A and Nguyen, Thien Huu}, + journal={arXiv e-prints}, + pages={arXiv--2307}, + year={2023} +} +``` + +### Groups and Tasks + +#### Groups + +- hellaswag_multilingual + +#### Tasks + +- `hellaswag_{ar,bn,ca,da,de,es,eu,fr,gu,hi,hr,hu,hy,id,it,kn,ml,mr,ne,nl,pt,ro,ru,sk,sr,sv,ta,te,uk,vi}` + + +### Checklist + +For adding novel benchmarks/datasets to the library: +* [x] Is the task an existing benchmark in the literature? + * [x] Have you referenced the original paper that introduced the task? + * [x] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test? + + +If other tasks on this dataset are already supported: +* [ ] Is the "Main" variant of this task clearly denoted? +* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates? +* [ ] Have you noted which, if any, published evaluation setups are matched by this variant? diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/_hellaswag_yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/_hellaswag_yaml new file mode 100644 index 0000000000..5be1d03ae0 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/_hellaswag_yaml @@ -0,0 +1,21 @@ +group: + - hellaswag_multilingual +dataset_path: null +dataset_name: null +output_type: multiple_choice +training_split: null +validation_split: validation +test_split: null +process_docs: !function utils.process_docs +doc_to_text: "query" +doc_to_target: "{{label.lstrip()}}" +doc_to_choice: "choices" +metric_list: + - metric: acc + aggregation: mean + higher_is_better: true + - metric: acc_norm + aggregation: mean + higher_is_better: true +metadata: + version: 1.0 diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ar.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ar.yaml new file mode 100644 index 0000000000..c88534613d --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ar.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_ar +dataset_path: alexandrainst/m_hellaswag +dataset_name: ar +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_bn.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_bn.yaml new file mode 100644 index 0000000000..67999829cd --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_bn.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_bn +dataset_path: alexandrainst/m_hellaswag +dataset_name: bn +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ca.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ca.yaml new file mode 100644 index 0000000000..0607ca9443 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ca.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_ca +dataset_path: alexandrainst/m_hellaswag +dataset_name: ca +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_da.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_da.yaml new file mode 100644 index 0000000000..608f8d5206 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_da.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_da +dataset_path: alexandrainst/m_hellaswag +dataset_name: da +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_de.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_de.yaml new file mode 100644 index 0000000000..6c103a8321 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_de.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_de +dataset_path: alexandrainst/m_hellaswag +dataset_name: de +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_es.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_es.yaml new file mode 100644 index 0000000000..78fa793d56 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_es.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_es +dataset_path: alexandrainst/m_hellaswag +dataset_name: es +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_eu.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_eu.yaml new file mode 100644 index 0000000000..7fdbaae7c2 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_eu.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_eu +dataset_path: alexandrainst/m_hellaswag +dataset_name: eu +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_fr.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_fr.yaml new file mode 100644 index 0000000000..d592478c81 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_fr.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_fr +dataset_path: alexandrainst/m_hellaswag +dataset_name: fr +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_gu.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_gu.yaml new file mode 100644 index 0000000000..0908b82381 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_gu.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_gu +dataset_path: alexandrainst/m_hellaswag +dataset_name: gu +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_hi.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_hi.yaml new file mode 100644 index 0000000000..c211078550 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_hi.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_hi +dataset_path: alexandrainst/m_hellaswag +dataset_name: hi +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_hr.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_hr.yaml new file mode 100644 index 0000000000..7e4b547b00 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_hr.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_hr +dataset_path: alexandrainst/m_hellaswag +dataset_name: hr +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_hu.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_hu.yaml new file mode 100644 index 0000000000..57bd4d7129 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_hu.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_hu +dataset_path: alexandrainst/m_hellaswag +dataset_name: hu +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_hy.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_hy.yaml new file mode 100644 index 0000000000..a00c55231c --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_hy.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_hy +dataset_path: alexandrainst/m_hellaswag +dataset_name: hy +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_id.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_id.yaml new file mode 100644 index 0000000000..4c3b39fdb2 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_id.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_id +dataset_path: alexandrainst/m_hellaswag +dataset_name: id +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_it.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_it.yaml new file mode 100644 index 0000000000..97be88b8e3 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_it.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_it +dataset_path: alexandrainst/m_hellaswag +dataset_name: it +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_kn.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_kn.yaml new file mode 100644 index 0000000000..40d924c85e --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_kn.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_kn +dataset_path: alexandrainst/m_hellaswag +dataset_name: kn +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ml.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ml.yaml new file mode 100644 index 0000000000..6337b4f682 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ml.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_ml +dataset_path: alexandrainst/m_hellaswag +dataset_name: ml +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_mr.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_mr.yaml new file mode 100644 index 0000000000..d4fbaff49e --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_mr.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_mr +dataset_path: alexandrainst/m_hellaswag +dataset_name: mr +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ne.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ne.yaml new file mode 100644 index 0000000000..75d12fb26c --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ne.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_ne +dataset_path: alexandrainst/m_hellaswag +dataset_name: ne +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_nl.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_nl.yaml new file mode 100644 index 0000000000..2c3ed2e8d6 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_nl.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_nl +dataset_path: alexandrainst/m_hellaswag +dataset_name: nl +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_pt.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_pt.yaml new file mode 100644 index 0000000000..7082b5a615 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_pt.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_pt +dataset_path: alexandrainst/m_hellaswag +dataset_name: pt +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ro.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ro.yaml new file mode 100644 index 0000000000..04b8d13747 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ro.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_ro +dataset_path: alexandrainst/m_hellaswag +dataset_name: ro +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ru.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ru.yaml new file mode 100644 index 0000000000..0a10a5e989 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ru.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_ru +dataset_path: alexandrainst/m_hellaswag +dataset_name: ru +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_sk.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_sk.yaml new file mode 100644 index 0000000000..7b831f755f --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_sk.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_sk +dataset_path: alexandrainst/m_hellaswag +dataset_name: sk +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_sr.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_sr.yaml new file mode 100644 index 0000000000..9dfae80cf0 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_sr.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_sr +dataset_path: alexandrainst/m_hellaswag +dataset_name: sr +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_sv.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_sv.yaml new file mode 100644 index 0000000000..8ca7d56778 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_sv.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_sv +dataset_path: alexandrainst/m_hellaswag +dataset_name: sv +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ta.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ta.yaml new file mode 100644 index 0000000000..16d4894290 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_ta.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_ta +dataset_path: alexandrainst/m_hellaswag +dataset_name: ta +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_te.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_te.yaml new file mode 100644 index 0000000000..92a846b6e6 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_te.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_te +dataset_path: alexandrainst/m_hellaswag +dataset_name: te +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_uk.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_uk.yaml new file mode 100644 index 0000000000..d675fb448b --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_uk.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_uk +dataset_path: alexandrainst/m_hellaswag +dataset_name: uk +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_vi.yaml b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_vi.yaml new file mode 100644 index 0000000000..6722d853e5 --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/hellaswag_vi.yaml @@ -0,0 +1,6 @@ +include: _hellaswag_yaml +task: hellaswag_vi +dataset_path: alexandrainst/m_hellaswag +dataset_name: vi +training_split: null +validation_split: val diff --git a/lm_eval/tasks/okapi/hellaswag_multilingual/utils.py b/lm_eval/tasks/okapi/hellaswag_multilingual/utils.py new file mode 100644 index 0000000000..62c0c23bcd --- /dev/null +++ b/lm_eval/tasks/okapi/hellaswag_multilingual/utils.py @@ -0,0 +1,24 @@ +import datasets +import re + + +def preprocess(text): + text = text.strip() + # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. + text = text.replace(" [title]", ". ") + text = re.sub("\\[.*?\\]", "", text) + text = text.replace(" ", " ") + return text + + +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + def _process_doc(doc): + ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize() + out_doc = { + "query": preprocess(doc["activity_label"] + ": " + ctx), + "choices": [preprocess(ending) for ending in doc["endings"]], + "gold": int(doc["label"]), + } + return out_doc + + return dataset.map(_process_doc) From 28ec7fa950346b5a895e85e1f3edd5648168acc4 Mon Sep 17 00:00:00 2001 From: Sam Passaglia <8333102+passaglia@users.noreply.github.com> Date: Fri, 5 Jan 2024 19:01:38 +0900 Subject: [PATCH 4/6] Do not escape ascii is logging outputs (#1246) * do not ensure ascii * Update __main__.py --------- Co-authored-by: Lintang Sutawika --- lm_eval/__main__.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/lm_eval/__main__.py b/lm_eval/__main__.py index 7fbee0dc73..5b362e38e8 100644 --- a/lm_eval/__main__.py +++ b/lm_eval/__main__.py @@ -248,7 +248,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: if results is not None: if args.log_samples: samples = results.pop("samples") - dumped = json.dumps(results, indent=2, default=_handle_non_serializable) + dumped = json.dumps( + results, indent=2, default=_handle_non_serializable, ensure_ascii=False + ) if args.show_config: print(dumped) @@ -264,7 +266,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ) filename = path.joinpath(f"{output_name}.jsonl") samples_dumped = json.dumps( - samples[task_name], indent=2, default=_handle_non_serializable + samples[task_name], + indent=2, + default=_handle_non_serializable, + ensure_ascii=False, ) filename.open("w").write(samples_dumped) From cf6a832143427f7b5e243c9fd4c5da5ff728be6f Mon Sep 17 00:00:00 2001 From: Lintang Sutawika Date: Mon, 8 Jan 2024 21:11:11 +0700 Subject: [PATCH 5/6] fixed fewshot loading for multiple input tasks (#1255) --- lm_eval/api/task.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/lm_eval/api/task.py b/lm_eval/api/task.py index f148b3b10a..455f512af8 100644 --- a/lm_eval/api/task.py +++ b/lm_eval/api/task.py @@ -787,16 +787,19 @@ def fewshot_context(self, doc, num_fewshot): ) example = self.doc_to_text(doc) - if isinstance(example, str): - return labeled_examples + example - elif isinstance(example, list): - return [labeled_examples + ex for ex in example] - elif isinstance(example, int): - if self.config.doc_to_choice is not None: - choices = self.doc_to_choice(doc) - return labeled_examples + choices[example] - else: - return labeled_examples + str(example) + if self.multiple_input: + return labeled_examples + else: + if isinstance(example, str): + return labeled_examples + example + elif isinstance(example, list): + return [labeled_examples + ex for ex in example] + elif isinstance(example, int): + if self.config.doc_to_choice is not None: + choices = self.doc_to_choice(doc) + return labeled_examples + choices[example] + else: + return labeled_examples + str(example) def apply_filters(self): if hasattr(self, "_filters"): @@ -952,7 +955,9 @@ def construct_requests( if self.multiple_input: # If there are multiple inputs, choices are placed in the ctx cont = self.doc_to_target(doc) - arguments = [(ctx, f"{target_delimiter}{cont}") for ctx in choices] + arguments = [ + (ctx + choice, f"{target_delimiter}{cont}") for choice in choices + ] else: # Otherwise they are placed in the continuation arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices] From ecb1df28f6de2495da560c21b891a00133372337 Mon Sep 17 00:00:00 2001 From: Stella Biderman Date: Mon, 8 Jan 2024 10:10:01 -0500 Subject: [PATCH 6/6] Revert citation (#1257) Over a dozen papers have used the updated citation block, but Google Scholar has noticed none of them. Since it does understand this citation, I think we should use it going forward until we have a way to ensure the newer citations are actually logged. --- README.md | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 462127a144..27b2033c40 100644 --- a/README.md +++ b/README.md @@ -301,14 +301,10 @@ The best way to get support is to open an issue on this repo or join the [Eleuth ## Cite as ``` -@misc{eval-harness, - author = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy}, - title = {A framework for few-shot language model evaluation}, - month = 12, - year = 2023, - publisher = {Zenodo}, - version = {v0.4.0}, - doi = {10.5281/zenodo.10256836}, - url = {https://zenodo.org/records/10256836} +@article{gao2021framework, + title={A framework for few-shot language model evaluation}, + author={Gao, Leo and Tow, Jonathan and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and McDonell, Kyle and Muennighoff, Niklas and others}, + journal={Version v0. 0.1. Sept}, + year={2021} } ```