
Commit

Merge branch 'main' into small-improvement-in-rouge
clefourrier authored Jul 17, 2024
2 parents 051f778 + 44f9a46 commit 5932c05
Showing 2 changed files with 27 additions and 12 deletions.
6 changes: 2 additions & 4 deletions src/lighteval/models/base_model.py
@@ -29,7 +29,7 @@
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset
 from lighteval.logging.hierarchical_logger import hlog, hlog_err, hlog_warn
@@ -88,9 +88,7 @@ def __init__(
         self.multichoice_continuations_start_space = config.multichoice_continuations_start_space
 
         # We are in DP (and launch the script with `accelerate launch`)
-        if not config.model_parallel and config.quantization_config is None:
-            # might need to use accelerate instead
-            # self.model = config.accelerator.prepare(self.model)
+        if not config.model_parallel and not isinstance(config.quantization_config, BitsAndBytesConfig):
             hlog(f"Using Data Parallelism, putting model on device {self._device}")
             self.model = self.model.to(self._device)
 
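The change above narrows the device-placement guard: only models loaded through a bitsandbytes `BitsAndBytesConfig` skip the manual `.to(device)` call, since `from_pretrained` already dispatches their weights and transformers rejects moving a 4-bit/8-bit bitsandbytes model with `.to()`; other setups (including GPTQ checkpoints) still get explicit placement. A minimal sketch of that behaviour outside lighteval, using a placeholder model name:

```python
# Sketch only (not lighteval code): why a BitsAndBytesConfig-loaded model must
# not be moved again. Requires a GPU plus the bitsandbytes package.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-causal-lm",  # placeholder checkpoint name
    quantization_config=quant_config,
    device_map="auto",  # bitsandbytes weights are placed on device here
)
# model.to("cuda:0") would be redundant and raises for 4-bit/8-bit bitsandbytes
# models, hence the `isinstance(..., BitsAndBytesConfig)` guard in base_model.py.
```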
33 changes: 25 additions & 8 deletions src/lighteval/models/model_config.py
@@ -95,7 +95,8 @@ class BaseModelConfig:
             Use `dtype="auto"` to derive the type from the model's weights.
         device (Union[int, str]): device to use for model training.
         quantization_config (Optional[BitsAndBytesConfig]): quantization
-            configuration for the model. Needed for 4-bit and 8-bit precision.
+            configuration for the model, manually provided to load a normally floating point
+            model at a quantized precision. Needed for 4-bit and 8-bit precision.
         trust_remote_code (bool): Whether to trust remote code during model
             loading.
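For illustration, a hedged sketch of the usage the reworded docstring describes: manually providing a quantization config so an otherwise full-precision checkpoint is loaded at reduced precision. Only `quantization_config` is confirmed by the docstring above; the `pretrained` field name and the checkpoint name are assumptions.

```python
from transformers import BitsAndBytesConfig

from lighteval.models.model_config import BaseModelConfig

# Assumed field name `pretrained` and placeholder checkpoint; only
# `quantization_config` is documented in the excerpt above.
config = BaseModelConfig(
    pretrained="some-org/full-precision-model",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
```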
@@ -144,13 +145,29 @@ def _init_configs(self, model_name: str, env_config: EnvConfig) -> PretrainedConfig:
             cache_dir=env_config.cache_dir,
             token=env_config.token,
         )
-        if getattr(auto_config, "quantization_config", False) and self.quantization_config is None:
-            if not is_autogptq_available():
-                raise ImportError(NO_AUTOGPTQ_ERROR_MSG)
-            hlog(
-                "`quantization_config` is None but was found in the model's config, using the one found in config.json"
-            )
-            self.quantization_config = GPTQConfig(**auto_config.quantization_config, disable_exllama=True)
+
+        # Gathering the model's automatic quantization config, if available
+        try:
+            model_auto_quantization_config = auto_config.quantization_config
+            hlog("An automatic quantization config was found in the model's config. Using it to load the model")
+        except (AttributeError, KeyError):
+            model_auto_quantization_config = None
+
+        if model_auto_quantization_config is not None:
+            if self.quantization_config is not None:
+                # We don't load models quantized by default with a different user provided conf
+                raise ValueError("You manually requested quantization on a model already quantized!")
+
+            # We add the quantization to the model params we store
+            if model_auto_quantization_config["quant_method"] == "gptq":
+                if not is_autogptq_available():
+                    raise ImportError(NO_AUTOGPTQ_ERROR_MSG)
+                auto_config.quantization_config["use_exllama"] = None
+                self.quantization_config = GPTQConfig(**auto_config.quantization_config, disable_exllama=True)
+            elif model_auto_quantization_config["quant_method"] == "bitsandbytes":
+                if not is_bnb_available():
+                    raise ImportError(NO_BNB_ERROR_MSG)
+                self.quantization_config = BitsAndBytesConfig(**auto_config.quantization_config)
 
         return auto_config

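The rewritten block reads the quantization config already stored in the checkpoint's config.json and dispatches on its `quant_method`, refusing a second, user-supplied config for an already quantized model. A self-contained sketch of the same dispatch follows; the helper name `resolve_quantization_config` is made up for illustration and is not lighteval API.

```python
from transformers import AutoConfig, BitsAndBytesConfig, GPTQConfig


def resolve_quantization_config(model_name: str, user_config=None):
    """Mirror of the dispatch above: prefer the config baked into config.json."""
    auto_config = AutoConfig.from_pretrained(model_name)
    baked_in = getattr(auto_config, "quantization_config", None)
    if baked_in is None:
        # Nothing stored with the checkpoint; keep whatever the user passed.
        return user_config
    if user_config is not None:
        raise ValueError("You manually requested quantization on a model already quantized!")
    if baked_in["quant_method"] == "gptq":
        # Same call as the diff above: reuse the stored settings, exllama disabled.
        return GPTQConfig(**baked_in, disable_exllama=True)
    if baked_in["quant_method"] == "bitsandbytes":
        return BitsAndBytesConfig(**baked_in)
    return None  # unknown quantization method: leave it to from_pretrained
```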
