Tiny improvements to endpoint_model.py, base_model.py,... #219

Merged
17 changes: 10 additions & 7 deletions src/lighteval/models/endpoints/endpoint_model.py
@@ -395,7 +395,7 @@ def _async_process_request(
grammar=grammar,
max_new_tokens=max_tokens,
stop_sequences=stop_tokens,
-# truncate=,
+truncate=self.max_length,
)

return generated_text
@@ -416,7 +416,7 @@ def _process_request(
grammar=grammar,
max_new_tokens=max_tokens,
stop_sequences=stop_tokens,
-# truncate=,
+truncate=self.max_length,
)

return generated_text
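
Both _async_process_request and _process_request now forward truncate=self.max_length instead of leaving the argument commented out, so over-long prompts are cut to the model's context window on the server side. A minimal sketch of the call shape, assuming huggingface_hub's InferenceClient; the endpoint URL, prompt, and lengths are illustrative:

from huggingface_hub import InferenceClient

client = InferenceClient(model="http://localhost:8080")  # TGI endpoint (illustrative)
max_length = 4096  # stands in for self.max_length

# truncate=max_length asks the server to keep only the last max_length
# input tokens instead of failing on an over-long prompt.
generated_text = client.text_generation(
    "A very long prompt ...",
    max_new_tokens=256,
    stop_sequences=["\n\n"],
    truncate=max_length,
)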
@@ -492,7 +492,7 @@ def greedy_until(

for _, _ in tqdm(
dataset.splits_start_end_iterator(),
-total=self.DATASET_SPLITS,
+total=dataset.num_dataset_splits,
desc="Splits",
position=0,
disable=self.disable_tqdm,
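
The progress bar's total now comes from the dataset instance rather than the class-level DATASET_SPLITS constant, so it stays accurate when the dataset is actually divided into a different number of chunks. A toy sketch of the idea; only splits_start_end_iterator and num_dataset_splits come from the diff, the rest is illustrative:

import math
from tqdm import tqdm

class ToyDataset:
    # Stand-in for lighteval's dataset wrapper (illustrative).
    def __init__(self, items, requested_splits):
        self.items = items
        # The effective split count can differ from any class-level
        # default, e.g. when there are fewer items than requested splits.
        self.num_dataset_splits = min(requested_splits, len(items))

    def splits_start_end_iterator(self):
        step = math.ceil(len(self.items) / self.num_dataset_splits)
        for start in range(0, len(self.items), step):
            yield start, min(start + step, len(self.items))

dataset = ToyDataset(list(range(10)), requested_splits=4)
# total matches what the iterator actually yields, so the bar ends at 100%.
for _, _ in tqdm(dataset.splits_start_end_iterator(), total=dataset.num_dataset_splits, desc="Splits"):
    pass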
@@ -514,12 +514,15 @@ def greedy_until(
responses = asyncio.run(self._async_process_batch_generate(batch))
else:
responses = self._process_batch_generate(batch)
-for response in responses:
+for i, response in enumerate(responses):
results.append(
GenerativeResponse(
result=response.generated_text,
logits=[item.logprob for item in response.details.prefill] if returns_logits else None,
-truncated_tokens_count=-1,
+generated_tokens=[token.id for token in response.details.tokens],
+truncated_tokens_count=max(
+    len(self.tokenizer.encode(batch[i].context)) - self.max_length, 0
+),
padded_tokens_count=-1,
)
)
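
Instead of the -1 placeholder, truncated_tokens_count now reports how many prompt tokens fell outside the context window: the tokenized context length minus max_length, floored at zero. The arithmetic in isolation; the tokenizer choice and lengths are illustrative:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works here
max_length = 8  # stands in for self.max_length

context = "a prompt that is comfortably longer than eight tokens, for demonstration"
n_context_tokens = len(tokenizer.encode(context))

# max(..., 0) floors the count: contexts shorter than max_length report
# zero truncated tokens rather than a negative number.
truncated_tokens_count = max(n_context_tokens - max_length, 0)
print(n_context_tokens, truncated_tokens_count)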
@@ -538,7 +541,7 @@ def loglikelihood(

for _, _ in tqdm(
dataset.splits_start_end_iterator(),
-total=self.DATASET_SPLITS,
+total=dataset.num_dataset_splits,
desc="Splits",
position=0,
disable=self.disable_tqdm,
@@ -589,7 +592,7 @@ def loglikelihood_rolling(

for _, _ in tqdm(
dataset.splits_start_end_iterator(),
-total=self.DATASET_SPLITS,
+total=dataset.num_dataset_splits,
desc="Splits",
position=0,
disable=self.disable_tqdm,
4 changes: 3 additions & 1 deletion src/lighteval/models/model_loader.py
@@ -116,7 +116,9 @@ def load_openai_model(config: OpenAIModelConfig, env_config: EnvConfig):
return model


-def load_model_with_inference_endpoints(config: InferenceEndpointModelConfig, env_config: EnvConfig):
+def load_model_with_inference_endpoints(
+    config: Union[InferenceEndpointModelConfig, ServerlessEndpointModelConfig], env_config: EnvConfig
+):
logger.info("Spin up model using inference endpoint.")
model = InferenceEndpointModel(config=config, env_config=env_config)
return model
5 changes: 1 addition & 4 deletions src/lighteval/models/transformers/base_model.py
@@ -826,10 +826,7 @@ def greedy_until(
input_ids=tokenized["input_ids"],
input_lengths=[len(item == 1) for item in tokenized["attention_mask"]],
input_mask=tokenized["attention_mask"],
-truncated=[
-    len(c) - tokenized["input_ids"].shape[1] if len(c) > tokenized["input_ids"].shape[1] else 0
-    for c in context
-],
+truncated=[max(len(c) - tokenized["input_ids"].shape[1], 0) for c in context],
padded=[sum(mask == 0) for mask in tokenized["attention_mask"]],
)
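
The four-line conditional collapses into max(..., 0), which is equivalent: a context longer than the padded batch width yields the positive difference, anything shorter clamps to zero. A quick equivalence check with made-up lengths:

batch_width = 5  # stands in for tokenized["input_ids"].shape[1]
contexts = ["ab", "abcdefgh"]  # illustrative raw contexts

old = [len(c) - batch_width if len(c) > batch_width else 0 for c in contexts]
new = [max(len(c) - batch_width, 0) for c in contexts]
assert old == new == [0, 3]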
