diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py
index 05397f092d..11140dad47 100644
--- a/tests/test_dataset_formatting.py
+++ b/tests/test_dataset_formatting.py
@@ -19,7 +19,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from trl.extras.dataset_formatting import get_formatting_func_from_dataset
-from trl.models.utils import ChatMlSpecialTokens, setup_chat_format
+from trl.models.utils import setup_chat_format
 
 
 class DatasetFormattingTestCase(unittest.TestCase):
@@ -119,29 +119,24 @@ class SetupChatFormatTestCase(unittest.TestCase):
     def setUp(self):
         self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
         self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
-        # remove built-in chat_template to simulate a model having no chat_template
+        # Remove built-in chat_template to simulate a model having no chat_template
         self.tokenizer.chat_template = None
 
     def test_setup_chat_format(self):
-        modified_model, modified_tokenizer = setup_chat_format(
-            self.model, self.tokenizer, format="chatml", resize_to_multiple_of=64
-        )
+        _, modified_tokenizer = setup_chat_format(self.model, self.tokenizer, format="chatml")
 
-        _chatml = ChatMlSpecialTokens()
         # Check if special tokens are correctly set
         self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>")
-        self.assertEqual(modified_tokenizer.pad_token, "<|im_end|>")
         self.assertEqual(modified_tokenizer.bos_token, "<|im_start|>")
-        self.assertEqual(modified_tokenizer.eos_token, _chatml.eos_token)
-        self.assertEqual(modified_tokenizer.pad_token, _chatml.pad_token)
-        self.assertEqual(modified_tokenizer.bos_token, _chatml.bos_token)
-        self.assertEqual((self.model.get_input_embeddings().weight.shape[0] % 64), 0)
+
+    def test_setup_chat_format_with_resize(self):
+        modified_model, _ = setup_chat_format(self.model, self.tokenizer, format="chatml", resize_to_multiple_of=123)
+
+        # Check that the input embeddings have been resized to a multiple of 123
+        self.assertEqual((modified_model.get_input_embeddings().weight.shape[0] % 123), 0)
 
     def test_example_with_setup_model(self):
-        modified_model, modified_tokenizer = setup_chat_format(
-            self.model,
-            self.tokenizer,
-        )
+        _, modified_tokenizer = setup_chat_format(self.model, self.tokenizer)
         messages = [
             {"role": "system", "content": "You are helpful"},
             {"role": "user", "content": "Hello"},
diff --git a/trl/models/utils.py b/trl/models/utils.py
index caa9b1df74..5cdc416183 100644
--- a/trl/models/utils.py
+++ b/trl/models/utils.py
@@ -16,7 +16,7 @@
 from contextlib import contextmanager
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch.nn as nn
 from packaging import version
@@ -76,58 +76,64 @@ def chat_template(self):
 def setup_chat_format(
     model: PreTrainedModel,
     tokenizer: PreTrainedTokenizer,
-    format: Optional[Literal["chatml"]] = "chatml",
+    format: str = "chatml",
     resize_to_multiple_of: Optional[int] = None,
 ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
     """
-    Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens.
+    Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the
+    embedding layer of the model based on the new special tokens.
 
-    If the model already has a chat template, this will throw an error. If you want to overwrite it, please set `tokenizer.chat_template` to `None`.
+    If the model already has a chat template, this will throw an error. If you want to overwrite it, please set
+    `tokenizer.chat_template` to `None` before calling this function.
 
     Args:
-        model (`~transformers.PreTrainedModel`): The model to be modified.
-        tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
-        format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
-        resize_to_multiple_of (`int` or `None`): Number to resize the embedding layer to. Defaults to None.
+        model (`~transformers.PreTrainedModel`):
+            Model to be modified.
+        tokenizer (`~transformers.PreTrainedTokenizer`):
+            Tokenizer to be modified.
+        format (`str`, *optional*, defaults to `"chatml"`):
+            Format to be set. Can be one of `{"chatml"}`.
+        resize_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
+            If not `None`, the model's embedding layer will be resized to a multiple of this number.
 
     Returns:
-        model (`~transformers.PreTrainedModel`): The modified model.
-        tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer.
+        model (`~transformers.PreTrainedModel`):
+            Modified model.
+        tokenizer (`~transformers.PreTrainedTokenizer`):
+            Modified tokenizer.
     """
-    # check if model already had a chat template
+    # Check if the model already has a chat template
     if tokenizer.chat_template is not None:
         raise ValueError(
-            "Chat template is already added to the tokenizer. If you want to overwrite it, please set it to None"
+            "Chat template is already added to the tokenizer. If you want to overwrite it, please set it to None "
+            "before calling this function."
        )
 
-    # check if format available and retrieve
     if format not in FORMAT_MAPPING:
-        raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")
+        raise ValueError(f"Format {format} not supported. Supported formats are: {', '.join(FORMAT_MAPPING.keys())}")
 
     chat_format = FORMAT_MAPPING[format]()
 
-    # set special tokens and them
+    # Set special tokens and chat template
+    tokenizer.chat_template = chat_format.chat_template
     tokenizer.eos_token = chat_format.eos_token
-    tokenizer.pad_token = chat_format.pad_token
     tokenizer.bos_token = chat_format.bos_token
     tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]})
-    # set chat format for tokenizer
-    tokenizer.chat_template = chat_format.chat_template
 
-    # resize embedding layer to a multiple of 64, https://x.com/karpathy/status/1621578354024677377
+    # Resize embedding layer
+    # This can lead to significant speedup, see https://x.com/karpathy/status/1621578354024677377
     model.resize_token_embeddings(
-        len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None
+        new_num_tokens=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
+        pad_to_multiple_of=resize_to_multiple_of,
     )
+
     # Update the model config to use the new eos & bos tokens
-    if getattr(model, "config", None) is not None:
-        model.config.pad_token_id = tokenizer.pad_token_id
-        model.config.bos_token_id = tokenizer.bos_token_id
-        model.config.eos_token_id = tokenizer.eos_token_id
+    model.config.bos_token_id = tokenizer.bos_token_id
+    model.config.eos_token_id = tokenizer.eos_token_id
+
     # Update the generation config to use the new eos & bos token
-    if getattr(model, "generation_config", None) is not None:
-        model.generation_config.bos_token_id = tokenizer.bos_token_id
-        model.generation_config.eos_token_id = tokenizer.eos_token_id
-        model.generation_config.pad_token_id = tokenizer.pad_token_id
+    model.generation_config.bos_token_id = tokenizer.bos_token_id
+    model.generation_config.eos_token_id = tokenizer.eos_token_id
 
     return model, tokenizer
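
For reference, a minimal sketch of how the updated `setup_chat_format` is intended to be called after this change (illustration only, not part of the patch; the tiny Qwen2 checkpoint is the same one the tests above use):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from trl.models.utils import setup_chat_format

    # Tiny test checkpoint used by the test suite above; any causal LM works the same way.
    model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
    tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")

    # setup_chat_format raises if a chat template is already set, so clear any built-in one first.
    tokenizer.chat_template = None

    # Adds the ChatML special tokens (<|im_start|>, <|im_end|>), sets the chat template,
    # and resizes the embedding layer to the next multiple of 64.
    model, tokenizer = setup_chat_format(model, tokenizer, format="chatml", resize_to_multiple_of=64)

    messages = [
        {"role": "system", "content": "You are helpful"},
        {"role": "user", "content": "Hello"},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False)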