Merge branch 'fix-precommit-issues' of https://github.com/GyoukChu/lm-evaluation-harness into fix-precommit-issues
GyoukChu committed Dec 12, 2024
2 parents 39a4081 + 046aeb7 commit 13ccedb
Showing 3 changed files with 43 additions and 4 deletions.
1 change: 0 additions & 1 deletion README.md
@@ -492,7 +492,6 @@ Extras dependencies can be installed via `pip install -e ".[NAME]"`
 | hf_transfer | For speeding up HF Hub file downloads |
 | ifeval | For running the IFEval task |
 | ibm_watsonx_ai | For using IBM watsonx.ai model apis |
-
 | neuronx | For running on AWS inf2 instances |
 | mamba | For loading Mamba SSM models |
 | math | For running math task answer checking |
44 changes: 42 additions & 2 deletions lm_eval/models/api_models.py
@@ -448,9 +448,13 @@ def batch_loglikelihood_requests(
         for chunk in chunks:
             for cache_key, context_enc, continuation_enc in chunk:
                 # max_length - 1 as we always have 1 token for generation
-                inp = (context_enc + continuation_enc)[-(self.max_length) :]
+                inp = (context_enc + continuation_enc)[-self.max_length :]
+                if len(inp) < len(context_enc + continuation_enc):
+                    eval_logger.warning(
+                        f"Context length ({len(context_enc)}) + continuation length ({len(continuation_enc)}) > max_length ({self.max_length}). Left truncating context."
+                    )
                 ctxlen = len(context_enc) - max(
-                    0, len(context_enc) + len(continuation_enc) - (self.max_length)
+                    0, len(context_enc) + len(continuation_enc) - self.max_length
                 )

                 inputs.append(inp)
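
As a quick aside on what this loglikelihood change does: the joined context + continuation token list is left-truncated to `max_length`, and `ctxlen` is reduced by the number of context tokens that were cut, which the new warning now surfaces. A minimal standalone sketch with made-up token ids (not code from the repository):

```python
# Toy illustration of the left-truncation arithmetic above
# (made-up token ids; standalone, not code from the repository).
max_length = 8
context_enc = [1, 2, 3, 4, 5, 6]      # 6 context tokens
continuation_enc = [7, 8, 9, 10]      # 4 continuation tokens

inp = (context_enc + continuation_enc)[-max_length:]  # keep only the last 8 tokens
ctxlen = len(context_enc) - max(
    0, len(context_enc) + len(continuation_enc) - max_length
)

assert inp == [3, 4, 5, 6, 7, 8, 9, 10]
assert ctxlen == 4  # 2 of the 6 context tokens were cut from the left
```

Here two context tokens are dropped, so only four of the six still count toward `ctxlen`.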
@@ -594,6 +598,24 @@ def _collate_gen(_requests):
             pbar = tqdm(desc="Requesting API", total=len(requests))
             for chunk in chunked:
                 contexts, all_gen_kwargs, encodings_list = zip(*chunk)
+                if self.tokenized_requests:
+                    max_gen_toks = all_gen_kwargs[0].get(
+                        "max_gen_toks", self._max_gen_toks
+                    )
+                    max_context_len = self.max_length - max_gen_toks
+
+                    encodings_list = [x[-max_context_len:] for x in encodings_list]
+
+                    if any(
+                        len(x) + max_gen_toks > self.max_length for x in encodings_list
+                    ):
+                        eval_logger.warning(
+                            f"Some contexts exceeded (max length: ({self.max_length}) - max_gen_toks: ({max_gen_toks}). They were left truncated."
+                        )
+                else:
+                    eval_logger.info(
+                        "Tokenized requests are disabled. Context + generation length is not checked."
+                    )
                 req = encodings_list if self.tokenized_requests else contexts
                 outputs = retry(
                     stop=stop_after_attempt(self.max_retries),
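
The same budget logic is added to both the synchronous branch above and the asynchronous branch below: prompts are left-truncated so that prompt length plus `max_gen_toks` stays within `max_length`. A rough standalone sketch with assumed values (not harness code):

```python
# Toy illustration of the prompt-budget check added in both branches
# (assumed values; standalone, not harness code).
max_length = 16
max_gen_toks = 6
max_context_len = max_length - max_gen_toks  # 10 tokens left for the prompt

encodings_list = [list(range(12)), list(range(4))]  # one long prompt, one short
encodings_list = [x[-max_context_len:] for x in encodings_list]  # left-truncate

# After truncation every prompt fits within max_length once generation is added;
# the harness runs its warning check against this already-truncated list.
assert all(len(x) + max_gen_toks <= max_length for x in encodings_list)
assert [len(x) for x in encodings_list] == [10, 4]
```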
@@ -625,6 +647,24 @@ def _collate_gen(_requests):
         else:
             for chunk in chunked:
                 contexts, all_gen_kwargs, encodings_list = zip(*chunk)
+                if self.tokenized_requests:
+                    max_gen_toks = all_gen_kwargs[0].get(
+                        "max_gen_toks", self._max_gen_toks
+                    )
+                    max_context_len = self.max_length - max_gen_toks
+
+                    encodings_list = [x[-max_context_len:] for x in encodings_list]
+
+                    if any(
+                        len(x) + max_gen_toks > self.max_length for x in encodings_list
+                    ):
+                        eval_logger.warning(
+                            f"Some contexts exceeded (max length: ({self.max_length}) - max_gen_toks ({max_gen_toks}). They were left truncated."
+                        )
+                else:
+                    eval_logger.info(
+                        "Tokenized requests are disabled. Context + generation length is not checked."
+                    )
                 req = encodings_list if self.tokenized_requests else contexts
                 results = itertools.chain.from_iterable(
                     asyncio.run(
2 changes: 1 addition & 1 deletion lm_eval/models/nemo_lm.py
@@ -187,11 +187,11 @@ def __init__(
         **kwargs,
     ):
         try:
-            from lightning.pytorch.trainer.trainer import Trainer
             from nemo.collections.nlp.modules.common.text_generation_utils import (
                 generate,
             )
             from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
+            from pytorch_lightning.trainer.trainer import Trainer

             self.generate = generate
         except ModuleNotFoundError as exception:
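
For reference, the touched import sits inside a guarded try/except so a missing dependency produces a clear error rather than a bare traceback. A hedged, standalone sketch of that pattern (the error message below is an assumption, not the repository's wording):

```python
# Minimal sketch of the guarded-import pattern this hunk touches
# (the error message is an assumption, not the repository's wording).
try:
    from pytorch_lightning.trainer.trainer import Trainer  # noqa: F401
except ModuleNotFoundError as exception:
    raise ImportError(
        "pytorch-lightning is required for the NeMo backend; "
        "install it with `pip install pytorch-lightning` and retry."
    ) from exception
```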
