
Commit

Merge branch 'main' into fix-drop
clefourrier authored Jul 17, 2024
2 parents 875064c + 44f9a46 commit 6ad484d
Showing 5 changed files with 46 additions and 36 deletions.
11 changes: 5 additions & 6 deletions src/lighteval/evaluator.py
@@ -64,7 +64,7 @@ def evaluate( # noqa: C901
     :return
         Dictionary of results
     """
-    # A request output tupe is a Tuple where the first element is the index of
+    # A request output tuple is a Tuple where the first element is the index of
     # the request for one document of one task i.e.
     # task: "arc_easy", doc: "0"# request: "0" -> request_index = 0,
     # We can have multiple requests per doc for multi choice tasks for example.
@@ -75,8 +75,11 @@ def evaluate( # noqa: C901
     )
     example_id_response_dict: dict[TaskExampleId, list[RequestIndexModelResponseTuple]] = collections.defaultdict(list)

-    for request_type, requests in requests_dict.items():
+    for request_type in RequestType:
+        if request_type not in requests_dict:
+            continue
         hlog(f"Running {request_type} requests")
+        requests = requests_dict[request_type]
         # These are all the request type from the request factory at the moment
         if request_type == RequestType.LOGLIKELIHOOD:
             full_resps = lm.loglikelihood(requests, override_bs=override_bs)
@@ -99,10 +102,6 @@ def evaluate( # noqa: C901

     # ===== unpack results and sort back in order and return control to Task =====
     for task_example_id, prediction_list in example_id_response_dict.items():
-        # ===== Unpack the request =====
-        prediction_list.sort(
-            key=lambda x: x.request_index
-        ) # When we use Loglikelihood for several tokens we have all the options here
         model_responses = [x.model_response for x in prediction_list]
         cur_task_name = task_example_id.task_name.rsplit("|", 1)[0]

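A minimal sketch (toy enum and a print in place of the real model calls; not lighteval code) of the dispatch pattern the new loop relies on: iterating over the RequestType enum itself walks the request types in a fixed, declaration-order sequence and skips any type for which no requests were built, instead of depending on the insertion order of requests_dict.

from enum import Enum, auto

class RequestType(Enum):
    LOGLIKELIHOOD = auto()
    LOGLIKELIHOOD_SINGLE_TOKEN = auto()
    GREEDY_UNTIL = auto()

# Insertion order puts GREEDY_UNTIL first, but the loop below still visits
# LOGLIKELIHOOD first because it follows the enum's declaration order.
requests_dict = {
    RequestType.GREEDY_UNTIL: ["generate answer for doc 0"],
    RequestType.LOGLIKELIHOOD: ["score choice A", "score choice B"],
}

for request_type in RequestType:
    if request_type not in requests_dict:  # no requests of this type for the current tasks
        continue
    requests = requests_dict[request_type]
    print(f"Running {request_type.name} requests: {requests}")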
8 changes: 3 additions & 5 deletions src/lighteval/metrics/__init__.py
@@ -116,24 +116,22 @@ def apply_generative_metric(

 def apply_multichoice_metric(results: list[ModelReturn], formatted_doc: Doc, metrics: list[Metric]):
     outputs = {}
-    if len(formatted_doc.choices) != len(results):
-        raise ValueError("Length of results is not equal to the length of the choices")
+    mc_results = results[: len(formatted_doc.choices)]
     if len(formatted_doc.choices) <= 1:
         raise ValueError(
             "You can't use a multi choice metric with only one choice. Use `acc_golds_likelihood` instead."
         )

     # Todo: make better system with return_bool_score instead of taking first element
-    choices_logprob = [results[i].result[0] for i in range(len(formatted_doc.choices))] # sum(
+    choices_logprob = [mc_results[i].result[0] for i in range(len(formatted_doc.choices))] # sum(
     gold_ixs = as_list(formatted_doc.gold_index)

     for metric in metrics:
         if metric.category == MetricCategory.MULTICHOICE:
             outputs.update(
                 metric.compute(choices_logprob=choices_logprob, gold_ixs=gold_ixs, formatted_doc=formatted_doc)
             )

-    return results, outputs
+    return results[len(formatted_doc.choices) :], outputs


 def apply_multichoice_metric_one_token(results: list[ModelReturn], formatted_doc: Doc, metrics: list[Metric]):
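A minimal sketch (toy data; not lighteval code) of the slice-and-hand-back pattern introduced above: the multichoice metric now consumes only the first len(choices) results and returns the unconsumed tail, so a document that also carries other kinds of results no longer trips the old equal-length check.

def consume_multichoice(results: list, choices: list[str]):
    # Take one loglikelihood result per choice from the front of the shared results list.
    mc_results = results[: len(choices)]
    best = max(range(len(choices)), key=lambda i: mc_results[i])
    # Hand the remaining results back for the next metric category to consume.
    return results[len(choices):], {"predicted_choice": choices[best]}

leftover, output = consume_multichoice([-1.2, -0.3, -2.5, "a generated answer"], ["A", "B", "C"])
print(output)    # {'predicted_choice': 'B'}
print(leftover)  # ['a generated answer']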
8 changes: 2 additions & 6 deletions src/lighteval/models/base_model.py
@@ -29,7 +29,7 @@
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset
 from lighteval.logging.hierarchical_logger import hlog, hlog_err, hlog_warn
@@ -88,9 +88,7 @@ def __init__(
         self.multichoice_continuations_start_space = config.multichoice_continuations_start_space

         # We are in DP (and launch the script with `accelerate launch`)
-        if not config.model_parallel and config.quantization_config is None:
-            # might need to use accelerate instead
-            # self.model = config.accelerator.prepare(self.model)
+        if not config.model_parallel and not isinstance(config.quantization_config, BitsAndBytesConfig):
             hlog(f"Using Data Parallelism, putting model on device {self._device}")
             self.model = self.model.to(self._device)

@@ -267,8 +265,6 @@ def _init_max_length(self, max_length) -> int:
             if hasattr(self._config, attr):
                 return getattr(self._config, attr)

-        if hasattr(self.tokenizer, "model_max_length"):
-            return self.tokenizer.model_max_length
         # Default max sequence length setting for when no `max_length` is provided
         # or no max length config setting is found in the model or tokenizer.
         return 2048
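A minimal sketch (stand-in classes; not lighteval code) of the device-placement guard changed above: a bitsandbytes-quantized model is already placed by its loader and should not be moved with `.to(device)`, so the data-parallel branch now checks the type of the quantization config rather than only whether one was provided.

from dataclasses import dataclass
from typing import Optional

class BitsAndBytesConfig:  # stand-in for transformers.BitsAndBytesConfig
    pass

class GPTQConfig:  # stand-in for transformers.GPTQConfig
    pass

@dataclass
class DummyModelConfig:
    model_parallel: bool = False
    quantization_config: Optional[object] = None

def should_move_to_device(config: DummyModelConfig) -> bool:
    # Move the model ourselves only in plain data-parallel mode, and only when the
    # quantization config is not a bitsandbytes one.
    return not config.model_parallel and not isinstance(config.quantization_config, BitsAndBytesConfig)

print(should_move_to_device(DummyModelConfig()))                                          # True
print(should_move_to_device(DummyModelConfig(quantization_config=GPTQConfig())))          # True: only bitsandbytes configs skip the move
print(should_move_to_device(DummyModelConfig(quantization_config=BitsAndBytesConfig())))  # False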
33 changes: 25 additions & 8 deletions src/lighteval/models/model_config.py
@@ -95,7 +95,8 @@ class BaseModelConfig:
             Use `dtype="auto"` to derive the type from the model's weights.
         device (Union[int, str]): device to use for model training.
         quantization_config (Optional[BitsAndBytesConfig]): quantization
-            configuration for the model. Needed for 4-bit and 8-bit precision.
+            configuration for the model, manually provided to load a normally floating point
+            model at a quantized precision. Needed for 4-bit and 8-bit precision.
         trust_remote_code (bool): Whether to trust remote code during model
             loading.
@@ -144,13 +145,29 @@ def _init_configs(self, model_name: str, env_config: EnvConfig) -> PretrainedCon
             cache_dir=env_config.cache_dir,
             token=env_config.token,
         )
-        if getattr(auto_config, "quantization_config", False) and self.quantization_config is None:
-            if not is_autogptq_available():
-                raise ImportError(NO_AUTOGPTQ_ERROR_MSG)
-            hlog(
-                "`quantization_config` is None but was found in the model's config, using the one found in config.json"
-            )
-            self.quantization_config = GPTQConfig(**auto_config.quantization_config, disable_exllama=True)
+
+        # Gathering the model's automatic quantization config, if available
+        try:
+            model_auto_quantization_config = auto_config.quantization_config
+            hlog("An automatic quantization config was found in the model's config. Using it to load the model")
+        except (AttributeError, KeyError):
+            model_auto_quantization_config = None
+
+        if model_auto_quantization_config is not None:
+            if self.quantization_config is not None:
+                # We don't load models quantized by default with a different user provided conf
+                raise ValueError("You manually requested quantization on a model already quantized!")
+
+            # We add the quantization to the model params we store
+            if model_auto_quantization_config["quant_method"] == "gptq":
+                if not is_autogptq_available():
+                    raise ImportError(NO_AUTOGPTQ_ERROR_MSG)
+                auto_config.quantization_config["use_exllama"] = None
+                self.quantization_config = GPTQConfig(**auto_config.quantization_config, disable_exllama=True)
+            elif model_auto_quantization_config["quant_method"] == "bitsandbytes":
+                if not is_bnb_available():
+                    raise ImportError(NO_BNB_ERROR_MSG)
+                self.quantization_config = BitsAndBytesConfig(**auto_config.quantization_config)

         return auto_config

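A minimal sketch (plain dicts in place of the transformers config and quantization classes; not lighteval code) of the resolution logic added above: a quantization section found in the model's own config.json takes precedence, a user-supplied quantization config on an already-quantized model is rejected, and the quant_method field decides which loader-side config would be built.

def resolve_quantization(auto_config: dict, user_quant_config: dict | None) -> dict | None:
    model_quant = auto_config.get("quantization_config")
    if model_quant is None:
        # No quantization baked into the checkpoint: keep whatever the user asked for.
        return user_quant_config
    if user_quant_config is not None:
        raise ValueError("You manually requested quantization on a model already quantized!")
    if model_quant["quant_method"] == "gptq":
        return {"kind": "GPTQConfig", **model_quant, "disable_exllama": True}
    if model_quant["quant_method"] == "bitsandbytes":
        return {"kind": "BitsAndBytesConfig", **model_quant}
    raise ValueError(f"Unsupported quant_method: {model_quant['quant_method']}")

print(resolve_quantization({"quantization_config": {"quant_method": "gptq", "bits": 4}}, None))
# {'kind': 'GPTQConfig', 'quant_method': 'gptq', 'bits': 4, 'disable_exllama': True}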
22 changes: 11 additions & 11 deletions src/lighteval/tasks/lighteval_task.py
@@ -539,6 +539,16 @@ def process_results(self, formatted_doc: Doc, results: list[ModelReturn]) -> dic
                 results=results, formatted_doc=formatted_doc, metrics=self.metrics
             )
             outputs.update(cur_outputs)
+        if self.has_metric_category[MetricCategory.MULTICHOICE]:
+            results, cur_outputs = apply_multichoice_metric(
+                results=results, formatted_doc=formatted_doc, metrics=self.metrics
+            )
+            outputs.update(cur_outputs)
+        if self.has_metric_category[MetricCategory.MULTICHOICE_ONE_TOKEN]:
+            results, cur_outputs = apply_multichoice_metric_one_token(
+                results=results, formatted_doc=formatted_doc, metrics=self.metrics
+            )
+            outputs.update(cur_outputs)
         if self.has_metric_category[MetricCategory.PERPLEXITY]:
             results, cur_outputs = apply_perplexity_metric(
                 results=results, formatted_doc=formatted_doc, metrics=self.metrics
@@ -557,16 +567,6 @@ def process_results(self, formatted_doc: Doc, results: list[ModelReturn]) -> dic
                 max_num_samples=max(self.num_samples),
             )
             outputs.update(cur_outputs)
-        if self.has_metric_category[MetricCategory.MULTICHOICE]:
-            results, cur_outputs = apply_multichoice_metric(
-                results=results, formatted_doc=formatted_doc, metrics=self.metrics
-            )
-            outputs.update(cur_outputs)
-        if self.has_metric_category[MetricCategory.MULTICHOICE_ONE_TOKEN]:
-            results, cur_outputs = apply_multichoice_metric_one_token(
-                results=results, formatted_doc=formatted_doc, metrics=self.metrics
-            )
-            outputs.update(cur_outputs)
         if (
             self.has_metric_category[MetricCategory.LLM_AS_JUDGE_MULTI_TURN]
             or self.has_metric_category[MetricCategory.LLM_AS_JUDGE]
@@ -643,7 +643,7 @@ def create_requests_from_tasks( # noqa: C901
 ) -> Tuple[dict[RequestType, list[Request]], dict[TaskExampleId, Doc]]:
     """
     Takes a task dict and a fewshot dict and returns a dict of requests, a dict
-    of docs, and a dict of requests origins. The construction of prompts and
+    of docs, and a dict of requests origins. The construction of prompts and
     thus the managing of few shots is done here.
     Args:
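A minimal sketch (toy handlers; not lighteval code) of why the reordering above matters, on one reading of the change: once apply_multichoice_metric consumes its results from the front of the shared list and returns the remainder, the handlers have to run in the same order in which the responses for a document were collected, so the multichoice blocks are moved ahead of the perplexity and generative ones.

def apply_multichoice(results: list, n_choices: int):
    # Consume one loglikelihood per choice from the front, hand back the rest.
    consumed, rest = results[:n_choices], results[n_choices:]
    return rest, {"mc_pred": max(range(n_choices), key=lambda i: consumed[i])}

def apply_generative(results: list):
    # Consume the single generation left for this document.
    consumed, rest = results[:1], results[1:]
    return rest, {"generation": consumed[0]}

results = [-0.4, -1.3, "a generated answer"]  # two choice logprobs first, then one generation
outputs = {}
results, cur_outputs = apply_multichoice(results, n_choices=2)
outputs.update(cur_outputs)
results, cur_outputs = apply_generative(results)
outputs.update(cur_outputs)
print(outputs)  # {'mc_pred': 0, 'generation': 'a generated answer'}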
