Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes wikitext prompts + some patches on tg models #64

Merged
merged 6 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/lighteval/logging/info_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import xxhash

from lighteval.logging.hierarchical_logger import hlog, hlog_warn
from lighteval.logging.hierarchical_logger import hlog_warn
from lighteval.metrics import MetricCategory
from lighteval.metrics.stderr import get_stderr_function
from lighteval.models.model_loader import ModelInfo
Expand Down Expand Up @@ -440,7 +440,7 @@ def aggregate(self, task_dict: dict[str, LightevalTask], bootstrap_iters: int =
try:
metric_result = task.aggregation()[metric_name](metric_values)
except OverflowError:
hlog(f"{task_name} {metric_name} OVERFLOW ERROR")
hlog_warn(f"{task_name}, {metric_name} got an OVERFLOW ERROR when aggregating.")
metric_result = float("nan")

if isinstance(metric_result, dict): # in which cases do we get a dict here?
Expand Down
4 changes: 2 additions & 2 deletions src/lighteval/metrics/metrics_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,6 @@ def compute(self, items: list[PerplexityCorpusMetricInput]):
if self.metric_type == "perplexity":
return math.exp(-np.mean(logprobs))
if self.metric_type == "weighted_perplexity":
return math.exp(-np.average(logprobs, weights=weights))
return math.exp(-sum(logprobs) / sum(weights))
if self.metric_type == "bits_per_byte":
return -np.average(logprobs, weights=weights) / math.log(2)
return -sum(logprobs) / sum(weights) * 1 / math.log(2)
NathanHB marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 4 additions & 4 deletions src/lighteval/models/endpoint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,10 @@ def loglikelihood(
responses = self.__process_batch_logprob(batch)
for ix, response in enumerate(responses):
len_choice = len(batch[ix].tokenized_continuation)
logits = [t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None]
results.append(
LoglikelihoodReturn(
result=[
t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None
],
result=sum(logits),
input_tokens=[t.id for t in response.details.prefill[:-len_choice]],
generated_tokens=[t.id for t in response.details.prefill[-len_choice:]],
truncated_tokens_count=-1,
Expand Down Expand Up @@ -329,9 +328,10 @@ def loglikelihood_rolling(
else:
responses = self.__process_batch_logprob(batch, rolling=True)
for response in responses:
logits = [t.logprob for t in response.details.tokens[:-1]]
results.append(
LoglikelihoodReturn(
result=[t.logprob for t in response.details.tokens[:-1]],
result=sum(logits),
input_tokens=[t.id for t in response.details.prefill],
generated_tokens=[t.id for t in response.details.tokens[:-1]],
truncated_tokens_count=-1,
Expand Down
2 changes: 1 addition & 1 deletion src/lighteval/models/model_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@dataclass
class ModelReturn: # @clefourrier: could probably an abstract class, but it might make the code too complex
class ModelReturn:
result: Union[tuple, list, str]
input_tokens: list[int] = field(default_factory=list) # model inputs
generated_tokens: list[int] = field(default_factory=list) # model generations
Expand Down
9 changes: 7 additions & 2 deletions src/lighteval/tasks/lighteval_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,10 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]:
# vs when it's used for the actual prompt. That's why we store whether we are currently using the
# doc for a fewshot sample (few_shots=True) or not, which then leads to the creation of a different Doc.
item["__few_shots"] = few_shots
docs.extend(as_list(self.formatter(item, self.name)))
cur_docs = self.formatter(item, self.name)
if cur_docs is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would it be empty? Not sure that's expected behaviour.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did some experiments with the original wikitext, and it contains empty rows, which we want to skip — it should not happen often, though.

continue
docs.extend(as_list(cur_docs))
return docs

def fewshot_docs(self) -> list[Doc]:
Expand Down Expand Up @@ -375,7 +378,9 @@ def construct_requests(
]
if self.has_metric_category[MetricCategory.PERPLEXITY]:
requests[RequestType.LOGLIKELIHOOD_ROLLING] += [
LoglikelihoodRollingRequest(task_name=current_task_name, doc_id=document_id_seed, ctx=context)
LoglikelihoodRollingRequest(
task_name=current_task_name, example_index=document_id_seed, request_index=0, context=context
)
]
if self.has_metric_category[MetricCategory.GENERATIVE]:
requests[RequestType.GREEDY_UNTIL] += [
Expand Down
2 changes: 1 addition & 1 deletion src/lighteval/tasks/tasks_prompt_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2065,7 +2065,7 @@ def wikifact(line, task_name: str = None):


def wikitext_103(line, task_name: str = None):
return Doc(task_name=task_name, query=line["text"])
return Doc(task_name=task_name, choices=[""], gold_index=0, query=line["text"])


def winogrande(line, task_name: str = None):
Expand Down
Loading