Adding inference endpoints models #12

Merged · 18 commits · Feb 7, 2024
Changes from 1 commit
fix model type
clefourrier committed Feb 6, 2024

commit 5cfeb9a69055b2285ffa9daacbcb5d28c5a073c6
8 changes: 4 additions & 4 deletions src/lighteval/models/adapter_model.py
@@ -1,7 +1,7 @@
from contextlib import nullcontext

import torch
-from transformers import AutoModel, PreTrainedTokenizer
+from transformers import AutoModelForCausalLM, PreTrainedTokenizer

from lighteval.logging.hierarchical_logger import hlog
from lighteval.models.base_model import BaseModel
@@ -20,7 +20,7 @@ def _create_auto_tokenizer(self, config: AdapterModelConfig, env_config: EnvConf
# (= the parent model, not the model of interest)
return self._create_auto_tokenizer_with_name(config.base_model, config=config, env_config=env_config)

-def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig) -> AutoModel:
+def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig) -> AutoModelForCausalLM:
"""Returns a PeftModel from a base model and a version fined tuned using PEFT."""
torch_dtype = _get_dtype(config.dtype, self._config)
config.model_parallel, max_memory, device_map = self.init_model_parallel(config.model_parallel)
@@ -31,7 +31,7 @@ def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig)

if self.accelerator.is_local_main_process if self.accelerator is not None else nullcontext():
hlog(f"Loading model from {adapter_weights} and applying adapter to {config.base_model}")
-base = AutoModel.from_pretrained(
+base = AutoModelForCausalLM.from_pretrained(
config.base_model, torch_dtype=torch.float16, low_cpu_mem_usage=True, token=env_config.token
)
# Should pass revision
@@ -43,7 +43,7 @@ def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig)

hlog(f"Loading model from {merged_path}")

-model = AutoModel.from_pretrained(
+model = AutoModelForCausalLM.from_pretrained(
merged_path,
max_memory=max_memory,
device_map=device_map,
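The switch from `AutoModel` to `AutoModelForCausalLM` matters here because the adapter path merges the fine-tuned weights into the base model and reloads the result, which has to keep its causal-LM head for generation and loglikelihood scoring. A minimal sketch of the overall pattern, assuming the `peft` library; the model ids and paths are placeholders, not the PR's exact code:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Load the base checkpoint with its LM head (plain AutoModel would drop it).
base = AutoModelForCausalLM.from_pretrained("base-model-id", torch_dtype=torch.float16)

# Apply the PEFT adapter weights on top of the base model.
model = PeftModel.from_pretrained(base, "adapter-weights-id")

# Merge the adapter into the base weights (works for LoRA-style adapters) and
# save a standalone checkpoint, which can then be reloaded with
# AutoModelForCausalLM as in the diff above.
merged = model.merge_and_unload()
merged.save_pretrained("merged-path")
```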
32 changes: 14 additions & 18 deletions src/lighteval/models/base_model.py
@@ -6,7 +6,7 @@
import transformers
from torch.utils.data import DataLoader
from tqdm import tqdm
-from transformers import AutoModel, AutoTokenizer, BatchEncoding
+from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding

from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset
from lighteval.logging.hierarchical_logger import hlog, hlog_err, hlog_warn
@@ -51,7 +51,7 @@ def __init__(
"""Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation."""
self.accelerator = config.accelerator
self._batch_size = config.batch_size
-self._max_length = config.max_length
+self._max_length = self._init_max_length(config.max_length)
self._config = config.init_configs(env_config)

self.add_special_tokens = config.add_special_tokens if config.add_special_tokens is not None else False
@@ -140,7 +140,7 @@ def _create_auto_model(self, config: BaseModelConfig, env_config: EnvConfig) ->
config.model_parallel, max_memory, device_map = self.init_model_parallel(config.model_parallel)
torch_dtype = _get_dtype(config.dtype, self._config)

-model = AutoModel.from_pretrained(
+model = AutoModelForCausalLM.from_pretrained(
config.pretrained,
revision=config.revision + (f"/{config.subfolder}" if config.subfolder is not None else ""),
max_memory=max_memory,
@@ -207,16 +207,11 @@ def _create_auto_tokenizer_with_name(

return tokenizer

-@property
-def eot_token(self) -> str:
-return self.tokenizer.eos_token
-
-@property
-def eot_token_id(self) -> int:
-return self.tokenizer.eos_token_id
-
@property
def max_length(self) -> int:
+return self._max_length
+
+def _init_max_length(self, max_length) -> int:
"""Return the maximum sequence length of the model.
NOTE: Different model configurations have different max sequence length
attribute names.
@@ -231,10 +226,10 @@ def max_length(self) -> int:
based on the model's configuration or tokenizer's model_max_length attribute.

Returns:
-None
+int: Max length to use depending on the available args and config
"""
-if self._max_length is not None:
-return self._max_length
+if max_length is not None:
+return max_length
# Try to get the sequence length from the model config.
seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
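With this change the `max_length` property simply returns a value cached at init time, and the fallback described in the docstring moves into the new `_init_max_length` helper called from `__init__` (see the earlier hunk). A rough sketch of the lookup order implied by the docstring and the `seqlen_config_attrs` tuple above; this is not the PR's exact code, and the attribute holding the transformers config is an assumption:

```python
def _init_max_length(self, max_length) -> int:
    # An explicit max_length passed through the model config wins.
    if max_length is not None:
        return int(max_length)
    # Otherwise probe the usual sequence-length attributes on the HF config.
    for attr in ("n_positions", "max_position_embeddings", "n_ctx"):
        if hasattr(self._config, attr):
            return getattr(self._config, attr)
    # Finally fall back to the tokenizer's model_max_length.
    return self.tokenizer.model_max_length
```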

@@ -263,6 +258,7 @@ def disable_tqdm(self) -> bool:
disable_tqdm = bool(not self.accelerator.is_main_process)
return disable_tqdm

+# Tokenization helpers
def tok_encode_pair(self, context, continuation):
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
@@ -374,7 +370,7 @@ def greedy_until(
list[GenerateReturn]: list of generated responses.
"""
for request in requests:
-request.stop_sequence = request.stop_sequence + [self.eot_token]
+request.stop_sequence = request.stop_sequence + [self.tokenizer.eos_token]
dataset = GenerativeTaskDataset(requests=requests, dataset_splits=dataset_splits)
starting_batch_size = STARTING_BATCH_SIZE
results = []
@@ -488,7 +484,7 @@ def loglikelihood(
"""
for request in requests:
if request.context == "":
-request.tokenized_context = [self.eot_token_id]
+request.tokenized_context = [self.tokenizer.eos_token_id]
request.tokenized_continuation = self.tok_encode(request.choice)
else:
# DO NOT CHANGE THE FOLLOWING LINE!
@@ -507,7 +503,7 @@ def loglikelihood_rolling(
"""This function is used to compute the log likelihood of the context for perplexity metrics."""

for request in requests: # tuple of one elem
-request.tokenized_context = [self.eot_token_id] # Fake context
+request.tokenized_context = [self.tokenizer.eos_token_id] # Fake context
request.tokenized_continuation = self.tok_encode(request.context)
# tokenized_reqs.append((("", context), fake_context_enc, context_enc))

@@ -735,7 +731,7 @@ def loglikelihood_single_token(
"""
for request in requests:
if request.context == "":
-request.tokenized_context = [self.eot_token_id]
+request.tokenized_context = [self.tokenizer.eos_token_id]
else:
request.tokenized_context = self.tok_encode(request.context)
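One convention repeated across the `loglikelihood*` hunks above: when a request arrives with an empty context, a lone EOS id is substituted as a "fake" context so the model always has a prefix token to condition on. Condensed into a single hypothetical helper (the field and method names come from the diff, the helper itself is illustrative):

```python
def _tokenize_loglikelihood_request(self, request):
    # Empty context: use the EOS id as a stand-in prefix so the continuation
    # is still conditioned on at least one token.
    if request.context == "":
        request.tokenized_context = [self.tokenizer.eos_token_id]
    else:
        request.tokenized_context = self.tok_encode(request.context)
    request.tokenized_continuation = self.tok_encode(request.choice)
    return request
```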

10 changes: 5 additions & 5 deletions src/lighteval/models/delta_model.py
@@ -2,7 +2,7 @@

import torch
from tqdm import tqdm
-from transformers import AutoModel
+from transformers import AutoModelForCausalLM

from lighteval.logging.hierarchical_logger import hlog
from lighteval.models.base_model import BaseModel
@@ -15,7 +15,7 @@ def _create_auto_model(
self,
config: DeltaModelConfig,
env_config: EnvConfig,
-) -> AutoModel:
+) -> AutoModelForCausalLM:
"""Returns a model created by adding the weights of a delta model to a base model."""
config.model_parallel, max_memory, device_map = self.init_model_parallel(config.model_parallel)
torch_dtype = _get_dtype(config.dtype, self._config)
@@ -26,10 +26,10 @@ def _create_auto_model(

if self.accelerator.is_main_process if self.accelerator is not None else nullcontext():
hlog(f"Loading base and delta models from {config.base_model} and {delta_model}")
-base = AutoModel.from_pretrained(
+base = AutoModelForCausalLM.from_pretrained(
config.base_model, torch_dtype=torch.float16, low_cpu_mem_usage=True, token=env_config.token
)
-delta = AutoModel.from_pretrained(
+delta = AutoModelForCausalLM.from_pretrained(
delta_model,
revision=config.revision + (f"/{config.subfolder}" if config.subfolder is not None else ""),
torch_dtype=torch.float16,
@@ -46,7 +46,7 @@ def _create_auto_model(

hlog(f"Loading delta-applied model from {delta_model}-delta-applied")

-model = AutoModel.from_pretrained(
+model = AutoModelForCausalLM.from_pretrained(
merged_path,
max_memory=max_memory,
device_map=device_map,
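For reference, "adding the weights of a delta model to a base model" (the docstring above) amounts to summing parameter tensors key by key and saving the merged checkpoint, which is why both models are now loaded through `AutoModelForCausalLM`. A minimal sketch of that idea with placeholder ids, not the PR's exact code:

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholders: substitute real Hub ids or local paths.
base = AutoModelForCausalLM.from_pretrained("base-model-id", torch_dtype=torch.float16)
delta = AutoModelForCausalLM.from_pretrained("delta-model-id", torch_dtype=torch.float16)

# Sum the tensors key by key; both checkpoints must share the same architecture.
delta_state = delta.state_dict()
with torch.no_grad():
    for name, tensor in base.state_dict().items():
        tensor += delta_state[name]  # in-place add updates the base model's weights

# Save the merged model; it can be reloaded with AutoModelForCausalLM as above.
base.save_pretrained("merged-model-path")
```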