From e29931259b2b2c9939f3f96d5ef5e994dfb77cf1 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 19 Aug 2024 14:59:24 -0400 Subject: [PATCH 01/89] optionally save the final FSDP model as a sharded state dict (#1828) * efficiently save very large llms when using FSDP * fix parsing and index of sharded chunks * only save fsdp on main process * debugging for rename * save sharded state dict * remove unused new param * get state dict directly * tweak acc merge fsdp to shard the weight files * sharded_state_dict alongside save_safetensors seems to hang on checkpoint save --- src/axolotl/cli/merge_sharded_fsdp_weights.py | 204 ++++++++++++++++++ src/axolotl/train.py | 21 +- .../config/models/input/v0_4_1/__init__.py | 17 ++ 3 files changed, 239 insertions(+), 3 deletions(-) create mode 100644 src/axolotl/cli/merge_sharded_fsdp_weights.py diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py new file mode 100644 index 0000000000..25408fd57e --- /dev/null +++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py @@ -0,0 +1,204 @@ +""" +This module provides a CLI to merge sharded FSDP model checkpoints into a single combined checkpoint +""" +import json +import logging +import os +import shutil +from pathlib import Path +from typing import Dict, Union + +import fire +import torch +import torch.distributed.checkpoint as dist_cp +import torch.distributed.checkpoint.format_utils as dist_cp_format_utils +import transformers +from accelerate.utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + is_torch_version, +) +from dotenv import load_dotenv +from huggingface_hub import split_torch_state_dict_into_shards +from safetensors.torch import save_file as safe_save_file +from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner + +from axolotl.cli import load_cfg, print_axolotl_text_art +from axolotl.common.cli import TrainerCliArgs + +LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights") + + +class BFloat16CastPlanner(_EmptyStateDictLoadPlanner): + """ + A custom planner to cast tensors to bfloat16 on the fly during loading. + """ + + def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument + tensor.copy_(tensor.to(torch.bfloat16)) + + +def _distributed_checkpoint_to_merged_weights( + checkpoint_dir: Union[str, Path], + save_path: str, + safe_serialization: bool = False, + max_shard_size: str = "5GB", +): + """ + Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save` + + Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`. + """ + + state_dict: Dict = {} + save_path_ = Path(save_path) + save_path_.mkdir(exist_ok=True) + dist_cp_format_utils._load_state_dict( # pylint: disable=protected-access + state_dict, + storage_reader=dist_cp.FileSystemReader(checkpoint_dir), + planner=BFloat16CastPlanner(), # pylint: disable=protected-access + no_dist=True, + ) + + # To handle if state is a dict like {model: {...}} + if len(state_dict.keys()) == 1: + state_dict = state_dict[list(state_dict)[0]] + + # Ensure all tensors are in bfloat16 + for key, value in state_dict.items(): + if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16: + state_dict[key] = value.to(torch.bfloat16) + + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + + filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace( + ".safetensors", "{suffix}.safetensors" + ) + state_dict_split = split_torch_state_dict_into_shards( + state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size + ) + # Save index if sharded + index = None + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + + # Save the model + filename_to_tensors = state_dict_split.filename_to_tensors.items() + + for shard_file, tensors in filename_to_tensors: + shard = {tensor: state_dict[tensor] for tensor in tensors} + + if safe_serialization: + safe_save_file( + shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"} + ) + else: + torch.save(shard, os.path.join(save_path_, shard_file)) + + if index is not None: + save_index_file = ( + SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME + ) + save_index_file = os.path.join(save_path_, save_index_file) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as fout: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + fout.write(content) + + return save_path_ + + +def merge_fsdp_weights( + checkpoint_dir: str, + output_path: str, + safe_serialization: bool = False, + remove_checkpoint_dir: bool = False, +): + """ + Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if + `SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if + `safe_serialization` else `pytorch_model.bin`. + + Note: this is a CPU-bound process. + + Args: + checkpoint_dir (`str`): + The directory containing the FSDP checkpoints (can be either the model or optimizer). + output_path (`str`): + The path to save the merged checkpoint. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the merged weights with safetensors (recommended). + remove_checkpoint_dir (`bool`, *optional*, defaults to `False`): + Whether to remove the checkpoint directory after merging. + """ + checkpoint_dir_ = Path(checkpoint_dir) + from accelerate.state import PartialState + + if not is_torch_version(">=", "2.3.0"): + raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`") + + # Verify that the checkpoint directory exists + if not checkpoint_dir_.exists(): + model_path_exists = (checkpoint_dir_ / "pytorch_model_fsdp_0").exists() + optimizer_path_exists = (checkpoint_dir_ / "optimizer_0").exists() + err = f"Tried to load from {checkpoint_dir_} but couldn't find a valid metadata file." + if model_path_exists and optimizer_path_exists: + err += ( + " However, potential model and optimizer checkpoint directories exist." + ) + err += f"Please pass in either {checkpoint_dir_}/pytorch_model_fsdp_0 or {checkpoint_dir_}/optimizer_0" + err += "instead." + elif model_path_exists: + err += " However, a potential model checkpoint directory exists." + err += ( + f"Please try passing in {checkpoint_dir_}/pytorch_model_fsdp_0 instead." + ) + elif optimizer_path_exists: + err += " However, a potential optimizer checkpoint directory exists." + err += f"Please try passing in {checkpoint_dir_}/optimizer_0 instead." + raise ValueError(err) + + # To setup `save` to work + state = PartialState() + if state.is_main_process: + LOG.info(f"Merging FSDP weights from {checkpoint_dir_}") + save_path = _distributed_checkpoint_to_merged_weights( + checkpoint_dir_, output_path, safe_serialization + ) + LOG.info(f"Successfully merged FSDP weights and saved to {save_path}") + if remove_checkpoint_dir: + LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}") + shutil.rmtree(checkpoint_dir_) + state.wait_for_everyone() + + +def do_cli(config: Path = Path("examples/"), **kwargs): + # pylint: disable=duplicate-code + print_axolotl_text_art() + parser = transformers.HfArgumentParser((TrainerCliArgs)) + parsed_cli_args, _ = parser.parse_args_into_dataclasses( + return_remaining_strings=True + ) + parsed_cli_args.merge_lora = True + + parsed_cfg = load_cfg( + config, + **kwargs, + ) + + fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0" + merge_fsdp_weights( + checkpoint_dir=str(fsdp_dir), + output_path=str(Path(parsed_cfg.output_dir) / "merged"), + safe_serialization=True, + ) + + +if __name__ == "__main__": + load_dotenv() + fire.Fire(do_cli) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index b8890d4f7a..b21b0b269c 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -12,6 +12,7 @@ import transformers.modelcard from accelerate import Accelerator from accelerate.logging import get_logger +from accelerate.utils import save_fsdp_model from datasets import Dataset from peft import PeftModel from pkg_resources import get_distribution # type: ignore @@ -194,9 +195,12 @@ def terminate_handler(_, __, model_weakref): if hasattr(module, "_post_training"): module._post_training(model, name) # pylint: disable=protected-access + state_dict_type = "FULL_STATE_DICT" if trainer.is_fsdp_enabled: - trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") - LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.") + if cfg.fsdp_final_state_dict_type: + state_dict_type = cfg.fsdp_final_state_dict_type + trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type) + LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.") if cfg.relora_steps: if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit): @@ -208,7 +212,18 @@ def terminate_handler(_, __, model_weakref): # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file if cfg.fsdp: - trainer.save_model(cfg.output_dir) + if ( + state_dict_type == "SHARDED_STATE_DICT" + and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT" + ): + save_fsdp_model( + trainer.accelerator.state.fsdp_plugin, + trainer.accelerator, + trainer.model, + cfg.output_dir, + ) + elif state_dict_type == "FULL_STATE_DICT": + trainer.save_model(cfg.output_dir) elif cfg.deepspeed and is_deepspeed_zero3_enabled(): # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading trainer.accelerator.wait_for_everyone() diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 5e690bb88e..dcc902c8c6 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -628,6 +628,9 @@ class Config: deepspeed: Optional[Union[str, Dict[str, Any]]] = None fsdp: Optional[List[str]] = None fsdp_config: Optional[Dict[str, Any]] = None + fsdp_final_state_dict_type: Optional[ + Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] + ] = None val_set_size: Optional[float] = Field(default=0.0) @@ -1148,6 +1151,20 @@ def check_fsdp_offload_w_8bit_optimizer(cls, data): ) return data + @model_validator(mode="before") + @classmethod + def check_fsdp_sharded_state_dict_w_safetensors(cls, data): + if ( + data.get("fsdp") + and data.get("save_safetensors") + and data.get("fsdp_config") + and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT" + ): + raise ValueError( + "FSDP SHARDED_STATE_DICT not compatible with save_safetensors" + ) + return data + @model_validator(mode="before") @classmethod def check_causal_lm_evals(cls, data): From 5aac4bc2846ac7379c0a52dd894dd1ba5499ec22 Mon Sep 17 00:00:00 2001 From: "Gal Cohen (galco)" Date: Tue, 20 Aug 2024 19:41:48 +0300 Subject: [PATCH 02/89] fix: dont change quant storage dtype in case of fsdp (#1837) * fix: dont change quant storage dtype in case of fsdp * fix black --------- Co-authored-by: Gal Cohen --- src/axolotl/utils/models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 5ac66260a7..3e8d50f5e7 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -544,7 +544,9 @@ def load_model( "bnb_4bit_quant_type": "nf4", "bnb_4bit_quant_storage": torch.bfloat16, } - if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed: + if cfg.model_config_type in ["jamba", "qwen2_moe"] and not ( + cfg.deepspeed or cfg.fsdp + ): # for some reason, this causes the loss to be off by an order of magnitude # but deepspeed needs this still in bfloat16 bnb_config["bnb_4bit_quant_storage"] = torch.float32 From 649c19aba31c022028bb508c1b945da9fe407e94 Mon Sep 17 00:00:00 2001 From: Aman Gupta Karmani Date: Wed, 21 Aug 2024 10:36:51 -0700 Subject: [PATCH 03/89] pretrain: fix with sample_packing=false (#1841) --- src/axolotl/utils/data/pretraining.py | 4 ++-- tests/test_data.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py index e056c7f509..16f38218cd 100644 --- a/src/axolotl/utils/data/pretraining.py +++ b/src/axolotl/utils/data/pretraining.py @@ -18,10 +18,10 @@ def encode_pretraining( - tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str] + tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List] ) -> Dict[str, List]: res = tokenizer( - examples, + examples["text"], truncation=True, max_length=max_tokens - 2, add_special_tokens=True, diff --git a/tests/test_data.py b/tests/test_data.py index 16af089a06..9d7f5a0412 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -35,7 +35,7 @@ def test_encode_pretraining(self): "hello, hello", ] } - result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"]) + result = encode_pretraining(self.tokenizer, self.max_tokens, examples) self.assertEqual(len(result["input_ids"]), 3) From 9f917245f69940bf69cf8bade13106ee8345c6aa Mon Sep 17 00:00:00 2001 From: "Gal Cohen (galco)" Date: Wed, 21 Aug 2024 20:37:17 +0300 Subject: [PATCH 04/89] feat: add jamba chat_template (#1843) * feat: add jamba chat_template * fix: black * feat: jamba fsdp+qlora --------- Co-authored-by: Gal Cohen --- examples/jamba/qlora_fsdp.yaml | 61 +++++++++++++++++++ src/axolotl/utils/chat_templates.py | 1 + .../config/models/input/v0_4_1/__init__.py | 1 + 3 files changed, 63 insertions(+) create mode 100644 examples/jamba/qlora_fsdp.yaml diff --git a/examples/jamba/qlora_fsdp.yaml b/examples/jamba/qlora_fsdp.yaml new file mode 100644 index 0000000000..2ea268344a --- /dev/null +++ b/examples/jamba/qlora_fsdp.yaml @@ -0,0 +1,61 @@ +base_model: ai21labs/Jamba-v0.1 +tokenizer_type: AutoTokenizer + +load_in_4bit: true +strict: false +use_tensorboard: true +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + chat_template: jamba + drop_system_message: true +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: jamba-fsdp-qlora-ft +save_safetensors: true +adapter: qlora +sequence_len: 2048 +sample_packing: true +pad_to_sequence_len: true + +lora_r: 16 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: [down_proj,gate_proj,in_proj,k_proj,o_proj,out_proj,q_proj,up_proj,v_proj,x_proj] +lora_target_linear: false + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 2 +optimizer: adamw_torch +lr_scheduler: cosine +learning_rate: 0.00001 + +train_on_inputs: false +group_by_length: false +bf16: true +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: true +logging_steps: 1 +flash_attention: true + +warmup_steps: 10 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 +fsdp: + - full_shard + - auto_wrap +fsdp_config: + fsdp_limit_all_gathers: true + fsdp_sync_module_states: true + fsdp_offload_params: false + fsdp_use_orig_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index ca4334d75a..51f88b1bdf 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -27,6 +27,7 @@ def chat_templates(user_choice: str): "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", + "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', } if user_choice in templates: diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index dcc902c8c6..65a2c5409a 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -190,6 +190,7 @@ class ChatTemplate(str, Enum): llama3 = "llama3" # pylint: disable=invalid-name phi_3 = "phi_3" # pylint: disable=invalid-name deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name + jamba = "jamba" # pylint: disable=invalid-name class LoftQConfig(BaseModel): From f07802f9fa9ae95f0b37ce626eaf21eca9fce738 Mon Sep 17 00:00:00 2001 From: Aman Gupta Karmani Date: Wed, 21 Aug 2024 10:37:51 -0700 Subject: [PATCH 05/89] examples: fix tiny-llama pretrain yml syntax (#1840) --- examples/tiny-llama/pretrain.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/tiny-llama/pretrain.yml b/examples/tiny-llama/pretrain.yml index e501dcb8e5..010a1608a3 100644 --- a/examples/tiny-llama/pretrain.yml +++ b/examples/tiny-llama/pretrain.yml @@ -9,9 +9,9 @@ strict: false max_steps: 200 pretraining_dataset: - path: c4 - name: en - type: pretrain + - path: allenai/c4 + name: en + type: pretrain dataset_prepared_path: val_set_size: 0.0 output_dir: ./outputs/model-out From 957c956f89ded7f9e8785c5ca303d4d95bd98453 Mon Sep 17 00:00:00 2001 From: "Gal Cohen (galco)" Date: Thu, 22 Aug 2024 16:22:55 +0300 Subject: [PATCH 06/89] rename jamba example (#1846) [skip ci] * rename jamba example * feat: change readme --------- Co-authored-by: Gal Cohen --- README.md | 77 +++++++++++-------- examples/jamba/README.md | 2 +- ...{qlora_fsdp.yaml => qlora_fsdp_large.yaml} | 4 +- 3 files changed, 47 insertions(+), 36 deletions(-) rename examples/jamba/{qlora_fsdp.yaml => qlora_fsdp_large.yaml} (94%) diff --git a/README.md b/README.md index a626635dc8..55a11d6c12 100644 --- a/README.md +++ b/README.md @@ -22,39 +22,49 @@ Features: ## Table of Contents -- [Introduction](#axolotl) -- [Supported Features](#axolotl-supports) -- [Quickstart](#quickstart-) -- [Environment](#environment) - - [Docker](#docker) - - [Conda/Pip venv](#condapip-venv) - - [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod - - [Bare Metal Cloud GPU](#bare-metal-cloud-gpu) - - [Windows](#windows) - - [Mac](#mac) - - [Google Colab](#google-colab) - - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot) - - [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack) -- [Dataset](#dataset) -- [Config](#config) - - [Train](#train) - - [Inference](#inference-playground) - - [Merge LORA to Base](#merge-lora-to-base) - - [Special Tokens](#special-tokens) - - [All Config Options](#all-config-options) -- Advanced Topics - - [Multipack](./docs/multipack.qmd) - - [RLHF & DPO](./docs/rlhf.qmd) - - [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd) - - [Unsloth](./docs/unsloth.qmd) -- [Common Errors](#common-errors-) - - [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training) -- [Debugging Axolotl](#debugging-axolotl) -- [Need Help?](#need-help-) -- [Badge](#badge-) -- [Community Showcase](#community-showcase) -- [Contributing](#contributing-) -- [Sponsors](#sponsors-) +- [Axolotl](#axolotl) + - [Table of Contents](#table-of-contents) + - [Axolotl supports](#axolotl-supports) + - [Quickstart ⚡](#quickstart-) + - [Usage](#usage) + - [Advanced Setup](#advanced-setup) + - [Environment](#environment) + - [Docker](#docker) + - [Conda/Pip venv](#condapip-venv) + - [Cloud GPU](#cloud-gpu) + - [Bare Metal Cloud GPU](#bare-metal-cloud-gpu) + - [LambdaLabs](#lambdalabs) + - [GCP](#gcp) + - [Windows](#windows) + - [Mac](#mac) + - [Google Colab](#google-colab) + - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot) + - [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack) + - [Dataset](#dataset) + - [Config](#config) + - [All Config Options](#all-config-options) + - [Train](#train) + - [Preprocess dataset](#preprocess-dataset) + - [Multi-GPU](#multi-gpu) + - [DeepSpeed](#deepspeed) + - [FSDP](#fsdp) + - [FSDP + QLoRA](#fsdp--qlora) + - [Weights \& Biases Logging](#weights--biases-logging) + - [Special Tokens](#special-tokens) + - [Inference Playground](#inference-playground) + - [Merge LORA to base](#merge-lora-to-base) + - [Common Errors 🧰](#common-errors-) + - [Tokenization Mismatch b/w Inference \& Training](#tokenization-mismatch-bw-inference--training) + - [Debugging Axolotl](#debugging-axolotl) + - [Need help? 🙋](#need-help-) + - [Badge ❤🏷️](#badge-️) + - [Community Showcase](#community-showcase) + - [Contributing 🤝](#contributing-) + - [Sponsors 🤝❤](#sponsors-) + - [💎 Diamond Sponsors - Contact directly](#-diamond-sponsors---contact-directly) + - [🥇 Gold Sponsors - $5000/mo](#-gold-sponsors---5000mo) + - [🥈 Silver Sponsors - $1000/mo](#-silver-sponsors---1000mo) + - [🥉 Bronze Sponsors - $500/mo](#-bronze-sponsors---500mo) @@ -96,6 +106,7 @@ Features: | RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | | Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ | | Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ | +| Jamba | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ | ✅: supported ❌: not supported diff --git a/examples/jamba/README.md b/examples/jamba/README.md index 54f5d1da9c..4c9dc85a06 100644 --- a/examples/jamba/README.md +++ b/examples/jamba/README.md @@ -6,5 +6,5 @@ - ✅ qlora w/ deepspeed Zero-3 needs at least 2x GPUs and 67GiB VRAM (wtf?) - ✅ qlora single-gpu, ~51GiB VRAM - ✅ multipack -- ❓ FSDP +- ✅ FSDP - ❓ 8-bit LoRA diff --git a/examples/jamba/qlora_fsdp.yaml b/examples/jamba/qlora_fsdp_large.yaml similarity index 94% rename from examples/jamba/qlora_fsdp.yaml rename to examples/jamba/qlora_fsdp_large.yaml index 2ea268344a..28316efd57 100644 --- a/examples/jamba/qlora_fsdp.yaml +++ b/examples/jamba/qlora_fsdp_large.yaml @@ -1,4 +1,4 @@ -base_model: ai21labs/Jamba-v0.1 +base_model: ai21labs/AI21-Jamba-1.5-Large tokenizer_type: AutoTokenizer load_in_4bit: true @@ -11,7 +11,7 @@ datasets: drop_system_message: true dataset_prepared_path: last_run_prepared val_set_size: 0.0 -output_dir: jamba-fsdp-qlora-ft +output_dir: jamba-large-fsdp-qlora-ft save_safetensors: true adapter: qlora sequence_len: 2048 From c3fc529bfc5a302e48218ec990b56b69c969127f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 22 Aug 2024 11:44:45 -0400 Subject: [PATCH 07/89] numpy 2.1.0 was released, but incompatible with numba (#1849) [skip ci] --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index dc74b916f8..be0c4927e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ optimum==1.16.2 hf_transfer colorama numba -numpy>=1.24.4 +numpy>=1.24.4,<=2.0.1 # qlora things evaluate==0.4.1 scipy From 5b0b774e38495b85c9ce1bdbaa2803d5daab1c92 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 22 Aug 2024 11:45:00 -0400 Subject: [PATCH 08/89] ensure that the bias is also in the correct dtype (#1848) [skip ci] * ensure that the bias is also in the correct dtype * add nightly for dpo-qlora-fsdp --- examples/qwen2/qlora-fsdp.yaml | 1 + src/axolotl/core/trainer_builder.py | 2 + src/axolotl/utils/models.py | 17 ++++- tests/e2e/multigpu/test_qwen2.py | 98 +++++++++++++++++++++++++++++ 4 files changed, 115 insertions(+), 3 deletions(-) create mode 100644 tests/e2e/multigpu/test_qwen2.py diff --git a/examples/qwen2/qlora-fsdp.yaml b/examples/qwen2/qlora-fsdp.yaml index 44f9c7e495..d61c72a378 100644 --- a/examples/qwen2/qlora-fsdp.yaml +++ b/examples/qwen2/qlora-fsdp.yaml @@ -72,4 +72,5 @@ fsdp_config: fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD special_tokens: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 4e8b369052..1a073ca047 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1846,6 +1846,8 @@ def build(self, total_num_steps): ) if self.cfg.fsdp: ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype) + if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model: + ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype) dpo_trainer = self.hook_post_create_trainer(dpo_trainer) for callback in self.get_post_trainer_create_callbacks(dpo_trainer): diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 3e8d50f5e7..4f47d59bfb 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1102,9 +1102,20 @@ def load_lora(model, cfg, inference=False, config_only=False): def ensure_dtype(model, dtype=torch.bfloat16): for name, module in model.named_modules(): + weight_mismatch = False + bias_mismatch = False try: - if module.weight.dtype != dtype: - print(f"Converting module {name}: {module.weight.dtype} -> {dtype}") - module.to(dtype) + weight_mismatch = module.weight.dtype != dtype except AttributeError: pass + try: + bias_mismatch = module.bias.dtype != dtype + except AttributeError: + pass + + if weight_mismatch: + print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}") + if bias_mismatch: + print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}") + if weight_mismatch or bias_mismatch: + module.to(dtype) diff --git a/tests/e2e/multigpu/test_qwen2.py b/tests/e2e/multigpu/test_qwen2.py new file mode 100644 index 0000000000..2513be69e5 --- /dev/null +++ b/tests/e2e/multigpu/test_qwen2.py @@ -0,0 +1,98 @@ +""" +E2E tests for multigpu qwen2 +""" + +import logging +import os +import unittest +from pathlib import Path + +import yaml +from accelerate.test_utils import execute_subprocess_async + +from axolotl.utils.dict import DictDefault + +from ..utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e.multigpu") +os.environ["WANDB_DISABLED"] = "true" + + +class TestMultiGPUQwen2(unittest.TestCase): + """ + Test case for Llama models using LoRA + """ + + @with_temp_dir + def test_qlora_fsdp_dpo(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2-1.5B", + "load_in_4bit": True, + "rl": "dpo", + "chat_template": "chatml", + "sequence_len": 2048, + "adapter": "qlora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.05, + "datasets": [ + { + "path": "Intel/orca_dpo_pairs", + "split": "train", + "type": "chatml.intel", + }, + ], + "num_epochs": 1, + "max_steps": 100, + "warmup_steps": 20, + "micro_batch_size": 4, + "gradient_accumulation_steps": 2, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "flash_attention": True, + "bf16": "auto", + "tf32": True, + "gradient_checkpointing": True, + "gradient_checkpointing_kwargs": { + "use_reentrant": False, + }, + "fsdp": [ + "full_shard", + "auto_wrap", + ], + "fsdp_config": { + "fsdp_limit_all_gathers": True, + "fsdp_offload_params": False, + "fsdp_sync_module_states": True, + "fsdp_use_orig_params": False, + "fsdp_cpu_ram_efficient_loading": False, + "fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer", + "fsdp_state_dict_type": "FULL_STATE_DICT", + "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_sharding_strategy": "FULL_SHARD", + }, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "accelerate", + "launch", + "--num-processes", + "2", + "-m", + "axolotl.cli.train", + str(Path(temp_dir) / "config.yaml"), + ] + ) From 9caa3eb699ef8eb1ad8c64011945d274172bfa63 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 22 Aug 2024 11:45:37 -0400 Subject: [PATCH 09/89] make the train_on_eos default to turn so all eos tokens are treated the same (#1847) [skip ci] --- src/axolotl/prompt_strategies/chat_template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index d0fad4483a..8ae668d7e9 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -357,7 +357,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): "train_on_inputs": cfg.train_on_inputs, "sequence_len": cfg.sequence_len, "roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]), - "train_on_eos": ds_cfg.get("train_on_eos", "last"), + "train_on_eos": ds_cfg.get("train_on_eos", "turn"), } strategy = ChatTemplateStrategy( From 7ed92e61c26bad40e1cc2151d1d9f934816a60e4 Mon Sep 17 00:00:00 2001 From: JohanWork <39947546+JohanWork@users.noreply.github.com> Date: Thu, 22 Aug 2024 17:46:57 +0200 Subject: [PATCH 10/89] fix: prompt phi (#1845) [skip ci] * corecting phi system prompt * phi test * update * add test --- src/axolotl/prompters.py | 6 ++++-- tests/test_prompters.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 0ffa3e55fd..13ff450f8a 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -65,8 +65,10 @@ def match_prompt_style(self): self.system_format = "<|im_start|>system\n{system}<|im_end|>\n" elif self.prompt_style == PromptStyle.PHI.value: self.turn_format = "<|user|>\n{instruction}<|end|>{input}<|assistant|>" - self.turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>" - self.system_format = "<|system|>{system}\n" + self.turn_no_input_format = ( + "<|user|>\n{instruction}<|end|>\n<|assistant|>\n" + ) + self.system_format = "<|system|>\n{system}<|end|>\n" def _build_result(self, instruction, input_text, output): # returns the full prompt from instruction and optional input diff --git a/tests/test_prompters.py b/tests/test_prompters.py index 6c5b8f27c2..3d61398e04 100644 --- a/tests/test_prompters.py +++ b/tests/test_prompters.py @@ -42,6 +42,19 @@ def test_prompt_style_w_instruct(self): assert "USER:" not in res assert "ASSISTANT:" not in res + def test_prompt_style_w_phi(self): + prompter = AlpacaPrompter(prompt_style=PromptStyle.PHI.value) + res = next(prompter.build_prompt("tell me a joke about the following")) + assert ( + """<|system|> +Below is an instruction that describes a task. Write a response that appropriately completes the request.<|end|> +<|user|> +tell me a joke about the following<|end|> +<|assistant|> +""" + == res + ) + def test_prompt_style_w_chat(self): prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value) res = next( From de4ea2d1f27074a54b1cc79301f426d2cc393a7f Mon Sep 17 00:00:00 2001 From: Aman Gupta Karmani Date: Thu, 22 Aug 2024 08:47:34 -0700 Subject: [PATCH 11/89] docs: minor syntax highlight fix (#1839) --- docs/unsloth.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/unsloth.qmd b/docs/unsloth.qmd index 390609fd33..90cb49bafa 100644 --- a/docs/unsloth.qmd +++ b/docs/unsloth.qmd @@ -34,7 +34,7 @@ unsloth_lora_o: true ``` These options are composable and can be used with multi-gpu finetuning -``` +```yaml unsloth_cross_entropy_loss: true unsloth_rms_norm: true unsloth_rope: true From 2f8037fee6cdee318df216049d0923455d80dad6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 22 Aug 2024 13:10:40 -0400 Subject: [PATCH 12/89] ensure that the hftrainer deepspeed config is set before the trainer class is ever init'ed (#1850) [skip ci] --- src/axolotl/utils/trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 26796f2e53..99c10c6558 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -399,12 +399,15 @@ def setup_torch_compile_env(cfg): def setup_deepspeed_env(cfg, stage=None): + from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed if stage: os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage) if stage == 3: os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true" + HfTrainerDeepSpeedConfig(cfg.deepspeed) def setup_fsdp_envs(cfg): From dcbff169830017ea4eb825aae1ef243cd124beda Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 22 Aug 2024 13:10:54 -0400 Subject: [PATCH 13/89] run nightly ci builds against upstream main (#1851) * run nightly ci builds against upstream main * add test badges * run the multigpu tests against nightly main builds too --- .github/workflows/multi-gpu-e2e.yml | 8 ++ .github/workflows/tests-nightly.yml | 116 ++++++++++++++++++++++++++++ README.md | 3 + cicd/Dockerfile.jinja | 8 ++ cicd/tests.py | 1 + setup.py | 2 +- 6 files changed, 137 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/tests-nightly.yml diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml index c854af9abe..91cbaf957e 100644 --- a/.github/workflows/multi-gpu-e2e.yml +++ b/.github/workflows/multi-gpu-e2e.yml @@ -18,6 +18,13 @@ jobs: pytorch: 2.3.1 axolotl_extras: num_gpus: 2 + - cuda: 121 + cuda_version: 12.1.1 + python_version: "3.11" + pytorch: 2.3.1 + axolotl_extras: + num_gpus: 2 + nightly_build: "true" runs-on: [self-hosted, modal] timeout-minutes: 120 steps: @@ -39,6 +46,7 @@ jobs: echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV + echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV - name: Run tests job on Modal run: | modal run cicd.multigpu diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml new file mode 100644 index 0000000000..23e48a85b0 --- /dev/null +++ b/.github/workflows/tests-nightly.yml @@ -0,0 +1,116 @@ +name: Tests +on: + workflow_dispatch: + schedule: + - cron: '0 0 * * *' # Runs at 00:00 UTC every day + +jobs: + pre-commit: + name: pre-commit + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: "3.10" + cache: 'pip' # caching pip dependencies + - uses: pre-commit/action@v3.0.0 + env: + SKIP: no-commit-to-branch + + pytest: + name: PyTest + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python_version: ["3.10", "3.11"] + timeout-minutes: 20 + + steps: + - name: Check out repository code + uses: actions/checkout@v3 + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python_version }} + cache: 'pip' # caching pip dependencies + + - name: Update requirements.txt + run: | + sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt + sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt + sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt + sed -i 's#^bitsandbytes.*#bitsandbytes @ git+https://github.com/bitsandbytes-foundation/bitsandbytes.git@main#' requirements.txt + + - name: Install dependencies + run: | + pip3 install --upgrade pip + pip3 install --upgrade packaging + pip3 install -U -e . + pip3 install -r requirements-tests.txt + + - name: Run tests + run: | + pytest --ignore=tests/e2e/ tests/ + + - name: cleanup pip cache + run: | + find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; + + docker-e2e-tests: + if: github.repository_owner == 'axolotl-ai-cloud' + # this job needs to be run on self-hosted GPU runners... + runs-on: [self-hosted, modal] + timeout-minutes: 60 + needs: [pre-commit, pytest] + + strategy: + fail-fast: false + matrix: + include: + - cuda: 121 + cuda_version: 12.1.1 + python_version: "3.10" + pytorch: 2.3.1 + num_gpus: 1 + axolotl_extras: mamba-ssm + nightly_build: "true" + - cuda: 121 + cuda_version: 12.1.1 + python_version: "3.11" + pytorch: 2.3.1 + num_gpus: 1 + axolotl_extras: mamba-ssm + nightly_build: "true" + - cuda: 124 + cuda_version: 12.4.1 + python_version: "3.11" + pytorch: 2.4.0 + num_gpus: 1 + axolotl_extras: + nightly_build: "true" + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Install Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + - name: Install Modal + run: | + python -m pip install --upgrade pip + pip install modal==0.63.64 jinja2 + - name: Update env vars + run: | + echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV + echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV + echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV + echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV + echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV + echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV + echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV + - name: Run tests job on Modal + run: | + modal run cicd.tests diff --git a/README.md b/README.md index 55a11d6c12..46cde54f8f 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ # Axolotl +![tests](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg) +![multigpu-semi-weekly tests](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg) + Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures. Features: diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja index 3a79883667..c245fce3ed 100644 --- a/cicd/Dockerfile.jinja +++ b/cicd/Dockerfile.jinja @@ -8,6 +8,7 @@ ENV BNB_CUDA_VERSION="{{ CUDA }}" ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}" ENV GITHUB_REF="{{ GITHUB_REF }}" ENV GITHUB_SHA="{{ GITHUB_SHA }}" +ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}" RUN apt-get update && \ apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev @@ -23,6 +24,13 @@ RUN git fetch origin +$GITHUB_REF && \ # If AXOLOTL_EXTRAS is set, append it in brackets RUN pip install causal_conv1d +RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \ + sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \ + sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \ + sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \ + sed -i 's#^bitsandbytes.*#bitsandbytes @ git+https://github.com/bitsandbytes-foundation/bitsandbytes.git@main#' requirements.txt; \ + fi + RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ else \ diff --git a/cicd/tests.py b/cicd/tests.py index c214676378..9c2d830cb7 100644 --- a/cicd/tests.py +++ b/cicd/tests.py @@ -28,6 +28,7 @@ "CUDA": os.environ.get("CUDA", "121"), "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"), "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""), + "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""), } dockerfile_contents = df_template.render(**df_args) diff --git a/setup.py b/setup.py index 1d164e0a18..1b64fadaef 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,7 @@ def parse_requirements(): dependency_links=dependency_links, extras_require={ "flash-attn": [ - "flash-attn==2.6.2", + "flash-attn==2.6.3", ], "fused-dense-lib": [ "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib", From b33dc07a7757819f68f6ceb4bbd2deb4c59c105b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 22 Aug 2024 13:13:33 -0400 Subject: [PATCH 14/89] rename nightly test and add badge (#1853) --- .github/workflows/tests-nightly.yml | 2 +- README.md | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml index 23e48a85b0..1440efe790 100644 --- a/.github/workflows/tests-nightly.yml +++ b/.github/workflows/tests-nightly.yml @@ -1,4 +1,4 @@ -name: Tests +name: Tests Nightly against upstream main on: workflow_dispatch: schedule: diff --git a/README.md b/README.md index 46cde54f8f..8c70da015d 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # Axolotl ![tests](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg) +![tests-nightly](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg) ![multigpu-semi-weekly tests](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg) Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures. From fefa95e35069a01c96583853e075bf0319e55e0a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 22 Aug 2024 16:39:23 -0400 Subject: [PATCH 15/89] most model types now support flash attention 2 regardless of multipack support (#1854) --- src/axolotl/monkeypatch/multipack.py | 1 + src/axolotl/utils/models.py | 14 ++++---------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 9043520108..44fc4cb473 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -17,6 +17,7 @@ "qwen2_moe", "falcon", "phi", + "phi3", "gemma", "gemma2", "gemmoe", diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 4f47d59bfb..8d24524a23 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -591,16 +591,10 @@ def load_model( "flash_attention_2" ) else: - if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES: - model_kwargs["attn_implementation"] = "flash_attention_2" - model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) - else: - model_kwargs["attn_implementation"] = "eager" - model_config._attn_implementation = ( # pylint: disable=protected-access - "eager" - ) + model_kwargs["attn_implementation"] = "flash_attention_2" + model_config._attn_implementation = ( # pylint: disable=protected-access + "flash_attention_2" + ) elif cfg.sdp_attention: model_kwargs["attn_implementation"] = "sdpa" model_config._attn_implementation = "sdpa" # pylint: disable=protected-access From 328fd4b3b74902eb39149496e5d9f7d9d4c4427f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 23 Aug 2024 11:40:21 -0400 Subject: [PATCH 16/89] add axolotl community license (#1862) --- src/axolotl/integrations/LICENSE.md | 58 +++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 src/axolotl/integrations/LICENSE.md diff --git a/src/axolotl/integrations/LICENSE.md b/src/axolotl/integrations/LICENSE.md new file mode 100644 index 0000000000..435d36d75b --- /dev/null +++ b/src/axolotl/integrations/LICENSE.md @@ -0,0 +1,58 @@ +### AXOLOTL COMMUNITY LICENSE AGREEMENT + +This Axolotl Community License Agreement (“Agreement”) is entered into by and between Axolotl AI Corp. (“Axolotl”) and +any individual or entity (“Licensee”) who wishes to use the Software (as defined below) in accordance with the terms +and conditions set forth in this Agreement. + +1. Definitions + 1.1 “Licensee” refers to any individual or entity who has obtained a copy of the Software under this Agreement. + 1.2 “Plugin Integration” means independent integration software modules which may or may not be offered by Axolotl, + which may be licensed separately by their respective authors and/or licensors. + 1.3 “Software” refers to the specific sub-directory of the Axolotl, Inc. software located at + https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations and its subdirectories which + permits Plugin Integrations to integrate with the Axolotl service. +2. Grant of License + 2.1 Axolotl hereby grants Licensee a worldwide, non-exclusive, royalty-free, license to use, copy, modify, merge, + publish, distribute, sublicense, and/or otherwise exploit the Software, subject to the following conditions: + - Licensee must comply with all the terms and conditions of this Agreement. + - Licensee must include the original copyright notice and disclaimer of warranty in all copies or substantial + portions of the Software. + 2.2 Licensee may use the Software for any lawful purpose, except as restricted in Section 3. +3. Restrictions + 3.1 Licensee shall not use the Software for any activity that constitutes a commercial activity of offering for + free or for sale any services, platform, or equivalent to third parties for the purposes of allowing such + third parties to fine-tune artificial intelligence models. + 3.2 Licensee shall not: + - Use the Software for any illegal or unauthorized purpose. + - Reverse engineer, decompile, or disassemble the Software. + - Remove or modify any copyright, trademark, or other proprietary notices contained in the Software. + - Use the Software in a way that could damage, disable, overburden, or impair the functionality of the + Software or interfere with any third-party use of the Software. + 3.3 Axolotl reserves the right to restrict certain Plugin Integrations for use with the Software. To the extent Licensee integrates a permitted, applicable Plugin Integration with the Software, Licensee shall comply with any additional terms and conditions imposed by the licensors of such Plugin Integration for use of such Plugin Integrations. Licensee shall contact Axolotl if it has questions about whether its use of the Software falls beyond the scope of this Agreement. +4. Intellectual Property Rights + 4.1 Axolotl and its contributors retain all intellectual property rights in and to the Software. Licensee + acknowledges that this Agreement does not transfer any ownership rights or intellectual property rights to + Licensee. +5. Disclaimer of Warranty + 5.1 THE SOFTWARE IS PROVIDED “AS IS,” WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED + TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. IN NO EVENT SHALL + THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN ACTION OF + CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + DEALINGS IN THE SOFTWARE. +6. Termination + 6.1 Axolotl may terminate this Agreement at any time if Licensee fails to comply with any of the terms and + conditions set forth herein. Upon termination, Licensee shall cease all use of the Software and destroy any + copies in its possession. +7. Governing Law + 7.1 This Agreement shall be governed by and construed in accordance with the laws of the State of California, + without regards to conflicts of laws provisions thereof. +8. Entire Agreement + 8.1 This Agreement constitutes the entire agreement between Axolotl and Licensee with respect to the subject matter + hereof and supersedes all prior or contemporaneous understandings or agreements between the parties concerning + the Software, whether written or oral. Axolotl may update the terms of this Agreement from time to time, and + Licensee’s continued use of the Software after any such updates shall constitute acceptance of updated terms + on a go-forward basis. Axolotl will use commercially reasonable efforts to provide Licensee notice of any + material updates. By using the Software, Licensee acknowledges that it has read, understood, and agrees to be + bound by the terms and conditions of this Agreement. + +This Agreement was last updated on August 23, 2024. From e8ff5d5738426bf604f79af9fffaf7dc050c6c2c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 23 Aug 2024 12:18:47 -0400 Subject: [PATCH 17/89] don't mess with bnb since it needs compiled wheels (#1859) --- .github/workflows/tests-nightly.yml | 1 - cicd/Dockerfile.jinja | 1 - 2 files changed, 2 deletions(-) diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml index 1440efe790..6b35698cbf 100644 --- a/.github/workflows/tests-nightly.yml +++ b/.github/workflows/tests-nightly.yml @@ -42,7 +42,6 @@ jobs: sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt - sed -i 's#^bitsandbytes.*#bitsandbytes @ git+https://github.com/bitsandbytes-foundation/bitsandbytes.git@main#' requirements.txt - name: Install dependencies run: | diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja index c245fce3ed..11ce8d8baa 100644 --- a/cicd/Dockerfile.jinja +++ b/cicd/Dockerfile.jinja @@ -28,7 +28,6 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \ sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \ sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \ sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \ - sed -i 's#^bitsandbytes.*#bitsandbytes @ git+https://github.com/bitsandbytes-foundation/bitsandbytes.git@main#' requirements.txt; \ fi RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ From 1f686c576c40a6dd7c8c785a133ac8fc09174b2e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 23 Aug 2024 12:21:51 -0400 Subject: [PATCH 18/89] Liger Kernel integration (#1861) * add initial plugin support w Liger kernel patches * integrate the input args classes * fix liger plugin and dynamic configuration class * drop untrainable samples and refactor config plugins integration * fix incorrect inputs and circular imports * fix bool comparison * fix for dropping untraibable tokens * fix licensing so liger integration is Apache 2.0 * add jamba support * pylint ignore --- .mypy.ini | 3 + requirements.txt | 2 + src/axolotl/cli/__init__.py | 6 + src/axolotl/integrations/base.py | 383 ++++++++++++++++++ src/axolotl/integrations/config.py | 65 +++ src/axolotl/integrations/liger/LICENSE | 202 +++++++++ src/axolotl/integrations/liger/__init__.py | 104 +++++ src/axolotl/integrations/liger/args.py | 32 ++ .../integrations/liger/models/jamba.py | 173 ++++++++ src/axolotl/utils/config/__init__.py | 18 +- src/axolotl/utils/models.py | 7 + src/axolotl/utils/trainer.py | 18 + 12 files changed, 1010 insertions(+), 3 deletions(-) create mode 100644 src/axolotl/integrations/base.py create mode 100644 src/axolotl/integrations/config.py create mode 100644 src/axolotl/integrations/liger/LICENSE create mode 100644 src/axolotl/integrations/liger/__init__.py create mode 100644 src/axolotl/integrations/liger/args.py create mode 100644 src/axolotl/integrations/liger/models/jamba.py diff --git a/.mypy.ini b/.mypy.ini index ede9fef887..c6d837d3f2 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -11,6 +11,9 @@ ignore_errors = True [mypy-axolotl.models.mixtral.*] ignore_errors = True +[mypy-axolotl.integrations.liger.models.*] +ignore_errors = True + [mypy-axolotl.models.phi.*] ignore_errors = True diff --git a/requirements.txt b/requirements.txt index be0c4927e4..f5fb547a26 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,6 +33,8 @@ gradio==3.50.2 tensorboard python-dotenv==1.0.1 autoawq>=0.2.5 +triton>=2.3.0 +liger-kernel mamba-ssm==1.2.0.post1 diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index a05ee84e97..aaa62423ca 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -27,6 +27,7 @@ from transformers.utils.import_utils import _is_package_available from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer +from axolotl.integrations.base import PluginManager from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta from axolotl.utils.config import ( @@ -365,6 +366,11 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): cfg.axolotl_config_path = config + if cfg.get("plugins"): + plugin_manager = PluginManager.get_instance() + for plugin_name in cfg["plugins"]: + plugin_manager.register(plugin_name) + try: device_props = torch.cuda.get_device_properties("cuda") gpu_version = "sm_" + str(device_props.major) + str(device_props.minor) diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py new file mode 100644 index 0000000000..d26eed90fe --- /dev/null +++ b/src/axolotl/integrations/base.py @@ -0,0 +1,383 @@ +# Copyright 2024 Axolotl AI. All rights reserved. +# +# This software may be used and distributed according to +# the terms of the Axolotl Community License Agreement (the "License"); +# you may not use this file except in compliance with the License. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. + +""" +Base class for all plugins. + +A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl. +Plugins can be used to integrate third-party models, modify the training process, or add new features. + +To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods. +""" +import importlib +import logging +from typing import List + + +class BasePlugin: + """ + Base class for all plugins. Defines the interface for plugin methods. + + Attributes: + None + + Methods: + register(cfg): Registers the plugin with the given configuration. + pre_model_load(cfg): Performs actions before the model is loaded. + post_model_load(cfg, model): Performs actions after the model is loaded. + pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded. + post_lora_load(cfg, model): Performs actions after LoRA weights are loaded. + create_optimizer(cfg, trainer): Creates and returns an optimizer for training. + create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler. + add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training. + add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training. + """ + + def __init__(self): + """ + Initializes the BasePlugin. + """ + + def register(self, cfg): + """ + Registers the plugin with the given configuration. + + Parameters: + cfg (dict): The configuration for the plugin. + + Returns: + None + """ + + def get_input_args(self): + """ + Returns a pydantic model for the plugin's input arguments. + """ + + def pre_model_load(self, cfg): + """ + Performs actions before the model is loaded. + + Parameters: + cfg (dict): The configuration for the plugin. + + Returns: + None + """ + + def post_model_load(self, cfg, model): + """ + Performs actions after the model is loaded. + + Parameters: + cfg (dict): The configuration for the plugin. + model (object): The loaded model. + + Returns: + None + """ + + def pre_lora_load(self, cfg, model): + """ + Performs actions before LoRA weights are loaded. + + Parameters: + cfg (dict): The configuration for the plugin. + model (object): The loaded model. + + Returns: + None + """ + + def post_lora_load(self, cfg, model): + """ + Performs actions after LoRA weights are loaded. + + Parameters: + cfg (dict): The configuration for the plugin. + model (object): The loaded model. + + Returns: + None + """ + + def create_optimizer(self, cfg, trainer): + """ + Creates and returns an optimizer for training. + + Parameters: + cfg (dict): The configuration for the plugin. + trainer (object): The trainer object for training. + + Returns: + object: The created optimizer. + """ + + def create_lr_scheduler(self, cfg, trainer, optimizer): + """ + Creates and returns a learning rate scheduler. + + Parameters: + cfg (dict): The configuration for the plugin. + trainer (object): The trainer object for training. + optimizer (object): The optimizer for training. + + Returns: + object: The created learning rate scheduler. + """ + + def add_callbacks_pre_trainer(self, cfg, model): + """ + Adds callbacks to the trainer before training. + + Parameters: + cfg (dict): The configuration for the plugin. + model (object): The loaded model. + + Returns: + List[callable]: A list of callback functions to be added to the TrainingArgs + """ + + def add_callbacks_post_trainer(self, cfg, trainer): + """ + Adds callbacks to the trainer after training. + + Parameters: + cfg (dict): The configuration for the plugin. + trainer (object): The trainer object for training. + + Returns: + List[callable]: A list of callback functions to be added to the TrainingArgs + """ + + +def load_plugin(plugin_name: str) -> BasePlugin: + """ + Loads a plugin based on the given plugin name. + + The plugin name should be in the format "module_name.class_name". + This function splits the plugin name into module and class, imports the module, + retrieves the class from the module, and creates an instance of the class. + + Parameters: + plugin_name (str): The name of the plugin to be loaded. The name should be in the format "module_name.class_name". + + Returns: + BasePlugin: An instance of the loaded plugin. + + Raises: + ImportError: If the plugin module cannot be imported. + """ + # split the plugin name into module and class + module_name, class_name = plugin_name.rsplit(".", 1) + + # import the module + module = importlib.import_module(module_name) + # instantiate the class + plugin_class = getattr(module, class_name) + # create an instance of the class + plugin = plugin_class() + + return plugin + + +class PluginManager: + """ + The PluginManager class is responsible for loading and managing plugins. + It should be a singleton so it can be accessed from anywhere in the codebase. + + Attributes: + plugins (List[BasePlugin]): A list of loaded plugins. + + Methods: + get_instance(): Static method to get the singleton instance of PluginManager. + register(plugin_name: str): Registers a new plugin by its name. + pre_model_load(cfg): Calls the pre_model_load method of all registered plugins. + """ + + plugins: List[BasePlugin] = [] + + _instance = None + + def __new__(cls): + """ + Creates a new instance of PluginManager if it doesn't exist yet. + """ + if cls._instance is None: + cls._instance = super(PluginManager, cls).__new__(cls) + cls._instance.plugins: List[BasePlugin] = [] + return cls._instance + + @staticmethod + def get_instance() -> "PluginManager": + """ + Returns the singleton instance of PluginManager. + If the instance doesn't exist, it creates a new one. + """ + if PluginManager._instance is None: + PluginManager() + return PluginManager._instance # type: ignore + + def register(self, plugin_name: str): + """ + Registers a new plugin by its name. + + Parameters: + plugin_name (str): The name of the plugin to be registered. + + Returns: + None + + Raises: + ImportError: If the plugin module cannot be imported. + """ + try: + plugin = load_plugin(plugin_name) + self.plugins.append(plugin) + except ImportError: + logging.error(f"Failed to load plugin: {plugin_name}") + + def get_input_args(self): + """ + Returns a list of Pydantic classes for all registered plugins' input arguments.' + + Returns: + list[str]: A list of Pydantic classes for all registered plugins' input arguments.' + """ + input_args = [] + for plugin in self.plugins: + input_args_from_plugin = plugin.get_input_args() + if input_args_from_plugin is not None: + input_args.append(input_args_from_plugin) + return input_args + + def pre_model_load(self, cfg): + """ + Calls the pre_model_load method of all registered plugins. + + Parameters: + cfg (dict): The configuration for the plugins. + + Returns: + None + """ + for plugin in self.plugins: + plugin.pre_model_load(cfg) + + def post_model_load(self, cfg, model): + """ + Calls the post_model_load method of all registered plugins. + + Parameters: + cfg (dict): The configuration for the plugins. + model (object): The loaded model. + + Returns: + None + """ + for plugin in self.plugins: + plugin.post_model_load(cfg, model) + + def pre_lora_load(self, cfg, model): + """ + Calls the pre_lora_load method of all registered plugins. + + Parameters: + cfg (dict): The configuration for the plugins. + model (object): The loaded model. + + Returns: + None + """ + for plugin in self.plugins: + plugin.pre_lora_load(cfg, model) + + def post_lora_load(self, cfg, model): + """ + Calls the post_lora_load method of all registered plugins. + + Parameters: + cfg (dict): The configuration for the plugins. + model (object): The loaded model. + + Returns: + None + """ + for plugin in self.plugins: + plugin.post_lora_load(cfg, model) + + def create_optimizer(self, cfg, trainer): + """ + Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer. + + Parameters: + cfg (dict): The configuration for the plugins. + trainer (object): The trainer object for training. + + Returns: + object: The created optimizer, or None if none was found. + """ + for plugin in self.plugins: + optimizer = plugin.create_optimizer(cfg, trainer) + if optimizer is not None: + return optimizer + return None + + def create_lr_scheduler(self, cfg, trainer, optimizer): + """ + Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler. + + Parameters: + cfg (dict): The configuration for the plugins. + trainer (object): The trainer object for training. + optimizer (object): The optimizer for training. + + Returns: + object: The created learning rate scheduler, or None if none was found. + """ + for plugin in self.plugins: + scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer) + if scheduler is not None: + return scheduler + return None + + def add_callbacks_pre_trainer(self, cfg, model): + """ + Calls the add_callbacks_pre_trainer method of all registered plugins. + + Parameters: + cfg (dict): The configuration for the plugins. + model (object): The loaded model. + + Returns: + List[callable]: A list of callback functions to be added to the TrainingArgs. + """ + callbacks = [] + for plugin in self.plugins: + callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model)) + return callbacks + + def add_callbacks_post_trainer(self, cfg, trainer): + """ + Calls the add_callbacks_post_trainer method of all registered plugins. + + Parameters: + cfg (dict): The configuration for the plugins. + trainer (object): The trainer object for training. + + Returns: + List[callable]: A list of callback functions to be added to the TrainingArgs. + """ + callbacks = [] + for plugin in self.plugins: + callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer)) + return callbacks diff --git a/src/axolotl/integrations/config.py b/src/axolotl/integrations/config.py new file mode 100644 index 0000000000..b4ffd6758f --- /dev/null +++ b/src/axolotl/integrations/config.py @@ -0,0 +1,65 @@ +# Copyright 2024 Axolotl AI. All rights reserved. +# +# This software may be used and distributed according to +# the terms of the Axolotl Community License Agreement (the "License"); +# you may not use this file except in compliance with the License. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. + +""" +module to handle merging the plugins' input arguments with the base configurations. + +this was moved here to prevent circular imports +""" + +from typing import Any, Dict, List + +from axolotl.utils.config.models.input.v0_4_1 import ( + AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, +) +from axolotl.utils.config.models.input.v0_4_1 import ( + AxolotlInputConfig as AxolotlInputConfigBase, +) + + +def merge_input_args(): + """ + Merges input arguments from registered plugins with the base configurations. + + This function retrieves the input arguments from registered plugins using the PluginManager. + It then dynamically creates new classes, AxolotlConfigWCapabilities and AxolotlInputConfig, + that inherit from the base configurations and include the input arguments from the plugins. + + Returns: + tuple: A tuple containing the newly created classes, AxolotlConfigWCapabilities and AxolotlInputConfig. + """ + from axolotl.integrations.base import PluginManager + + plugin_manager = PluginManager.get_instance() + input_args: List[str] = plugin_manager.get_input_args() + plugin_classes = [] + dynamic_input = "" + for plugin_args in input_args: + plugin_module, plugin_cls = plugin_args.rsplit(".", 1) + dynamic_input += f"from {plugin_module} import {plugin_cls}\n" + plugin_classes.append(plugin_cls) + if dynamic_input: + dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n" + dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n" + + namespace: Dict[Any, Any] = {} + exec( # pylint: disable=exec-used # nosec B102 + dynamic_input, globals(), namespace + ) + AxolotlInputConfig = namespace[ # pylint: disable=invalid-name + "AxolotlInputConfig" + ] + AxolotlConfigWCapabilities = namespace[ # pylint: disable=invalid-name + "AxolotlConfigWCapabilities" + ] + return AxolotlConfigWCapabilities, AxolotlInputConfig + return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase diff --git a/src/axolotl/integrations/liger/LICENSE b/src/axolotl/integrations/liger/LICENSE new file mode 100644 index 0000000000..d645695673 --- /dev/null +++ b/src/axolotl/integrations/liger/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py new file mode 100644 index 0000000000..d4c1ad9a4d --- /dev/null +++ b/src/axolotl/integrations/liger/__init__.py @@ -0,0 +1,104 @@ +# Copyright 2024 Axolotl AI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Module for the Plugin for LIGER integraton with Axolotl. + +Liger Kernel is the collection of Triton-native kernels for LLM Training. +It is designed to be performant, correct, and light-weight. +""" +import logging + +from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss +from liger_kernel.transformers.geglu import LigerGEGLUMLP +from liger_kernel.transformers.model.llama import lce_forward +from liger_kernel.transformers.rms_norm import LigerRMSNorm +from liger_kernel.transformers.rope import liger_rotary_pos_emb +from liger_kernel.transformers.swiglu import LigerSwiGLUMLP + +from axolotl.integrations.base import BasePlugin + +from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 + + +class LigerPlugin(BasePlugin): + """ + Plugin for LIGER integraton with Axolotl. + """ + + def get_input_args(self): + return "axolotl.integrations.liger.LigerArgs" + + def pre_model_load(self, cfg): + if cfg.model_config_type == "llama": + from transformers.models.llama import modeling_llama + + if cfg.liger_rope: + modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb + if cfg.liger_rms_norm: + modeling_llama.LlamaRMSNorm = LigerRMSNorm + if cfg.liger_swiglu: + modeling_llama.LlamaMLP = LigerSwiGLUMLP + if cfg.liger_cross_entropy: + modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss + elif cfg.liger_fused_linear_cross_entropy: + modeling_llama.LlamaForCausalLM.forward = lce_forward + + elif cfg.model_config_type == "mistral": + from transformers.models.mistral import modeling_mistral + + if cfg.liger_rope: + modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb + if cfg.liger_rms_norm: + modeling_mistral.MistralRMSNorm = LigerRMSNorm + if cfg.liger_swiglu: + modeling_mistral.MistralMLP = LigerSwiGLUMLP + if cfg.liger_cross_entropy: + modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss + if cfg.liger_fused_linear_cross_entropy: + logging.warning( + "Fused linear cross entropy is not supported for Mistral." + ) + + elif cfg.model_config_type == "gemma": + from transformers.models.gemma import modeling_gemma + + if cfg.liger_rope: + modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb + if cfg.liger_rms_norm: + modeling_gemma.GemmaRMSNorm = LigerRMSNorm + if cfg.liger_swiglu: + modeling_gemma.GemmaMLP = LigerGEGLUMLP + if cfg.liger_cross_entropy: + modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss + if cfg.liger_fused_linear_cross_entropy: + logging.warning( + "Fused linear cross entropy is not supported for Gemma." + ) + + elif cfg.model_config_type == "jamba": + from transformers.models.jamba import modeling_jamba + + from .models.jamba import lce_forward as jamba_lce_forward + + if cfg.liger_rope: + modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb + if cfg.liger_rms_norm: + modeling_jamba.JambaRMSNorm = LigerRMSNorm + if cfg.liger_swiglu: + modeling_jamba.JambaMLP = LigerSwiGLUMLP + if cfg.liger_cross_entropy: + modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss + if cfg.liger_fused_linear_cross_entropy: + modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward diff --git a/src/axolotl/integrations/liger/args.py b/src/axolotl/integrations/liger/args.py new file mode 100644 index 0000000000..decdb37750 --- /dev/null +++ b/src/axolotl/integrations/liger/args.py @@ -0,0 +1,32 @@ +# Copyright 2024 Axolotl AI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Module for handling LIGER input arguments. +""" +from typing import Optional + +from pydantic import BaseModel + + +class LigerArgs(BaseModel): + """ + Input args for LIGER. + """ + + liger_rope: Optional[bool] = None + liger_rms_norm: Optional[bool] = None + liger_swiglu: Optional[bool] = None + liger_cross_entropy: Optional[bool] = None + liger_fused_linear_cross_entropy: Optional[bool] = None diff --git a/src/axolotl/integrations/liger/models/jamba.py b/src/axolotl/integrations/liger/models/jamba.py new file mode 100644 index 0000000000..40cec63a4f --- /dev/null +++ b/src/axolotl/integrations/liger/models/jamba.py @@ -0,0 +1,173 @@ +""" +Jamba model with LigerFusedLinearCrossEntropyLoss +""" +# pylint: disable=duplicate-code + +from typing import Optional, Tuple, Union + +import torch +from liger_kernel.transformers.fused_linear_cross_entropy import ( + LigerFusedLinearCrossEntropyLoss, +) +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import MoeCausalLMOutputWithPast +from transformers.models.jamba.modeling_jamba import ( + _CONFIG_FOR_DOC, + JAMBA_INPUTS_DOCSTRING, + HybridMambaAttentionDynamicCache, + load_balancing_loss_func, +) +from transformers.utils import ( + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) + + +@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[Union[int, None]] = None, +) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, JambaForCausalLM + + >>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + cache_position=cache_position, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + loss = None + logits = None + + if self.training: + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + lce = LigerFusedLinearCrossEntropyLoss() + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + else: + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states) + else: + logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :]) + logits = logits.float() + + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index ed165e89ca..82436e8d79 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -8,11 +8,14 @@ import torch from transformers.utils import is_torch_bf16_gpu_available +from axolotl.integrations.config import merge_input_args from axolotl.utils.bench import log_gpu_memory_usage +from axolotl.utils.config.models.input.v0_4_1 import SUPPORTED_METRICS from axolotl.utils.config.models.input.v0_4_1 import ( - SUPPORTED_METRICS, - AxolotlConfigWCapabilities, - AxolotlInputConfig, + AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, +) +from axolotl.utils.config.models.input.v0_4_1 import ( + AxolotlInputConfig as AxolotlInputConfigBase, ) from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model_config @@ -207,6 +210,15 @@ def normalize_cfg_datasets(cfg): def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None): + AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase + AxolotlInputConfig = AxolotlInputConfigBase + + if cfg.plugins: + ( + AxolotlConfigWCapabilities, # pylint: disable=invalid-name + AxolotlInputConfig, # pylint: disable=invalid-name + ) = merge_input_args() + if capabilities: return DictDefault( dict( diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 8d24524a23..6261ce20fe 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -308,10 +308,17 @@ def load_model( """ Load a model for a given configuration and tokenizer. """ + base_model = cfg.base_model model_type = cfg.type_of_model model_config = load_model_config(cfg) + # load any patches from plugins + from axolotl.integrations.base import PluginManager + + plugin_manager = PluginManager.get_instance() + plugin_manager.pre_model_load(cfg) + # TODO refactor as a kwarg load_in_8bit = cfg.load_in_8bit diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 99c10c6558..f4e1fc6cb8 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -217,6 +217,24 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): desc="Dropping Long Sequences", ) + # drop samples with where the number of elements with labels not equal to -100 is zero + def drop_no_trainable_tokens(sample): + return np.sum(np.array(sample["labels"]) != -100) > 0 + + train_dataset = train_dataset.filter( + drop_no_trainable_tokens, + num_proc=cfg.dataset_processes, + load_from_cache_file=not cfg.is_preprocess, + desc="Drop Samples with Zero Trainable Tokens", + ) + if eval_dataset: + eval_dataset = eval_dataset.filter( + drop_no_trainable_tokens, + num_proc=cfg.dataset_processes, + load_from_cache_file=not cfg.is_preprocess, + desc="Drop Samples with Zero Trainable Tokens", + ) + if cfg.group_by_length: train_dataset = train_dataset.map( add_length, From da0d581a8cbe26b9a24b8e479751f1e87687ae27 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 23 Aug 2024 12:37:50 -0400 Subject: [PATCH 19/89] add liger example (#1864) --- examples/llama-3/fft-8b-liger-fsdp.yaml | 76 +++++++++++++++++++++++++ examples/llama-3/fft-8b.yaml | 4 +- 2 files changed, 77 insertions(+), 3 deletions(-) create mode 100644 examples/llama-3/fft-8b-liger-fsdp.yaml diff --git a/examples/llama-3/fft-8b-liger-fsdp.yaml b/examples/llama-3/fft-8b-liger-fsdp.yaml new file mode 100644 index 0000000000..a64965d207 --- /dev/null +++ b/examples/llama-3/fft-8b-liger-fsdp.yaml @@ -0,0 +1,76 @@ +base_model: NousResearch/Meta-Llama-3.1-8B + +plugins: + - axolotl.integrations.liger.LigerPlugin +liger_rope: true +liger_rms_norm: true +liger_swiglu: true +liger_fused_linear_cross_entropy: true + +strict: false + +chat_template: llama3 +datasets: + - path: mlabonne/FineTome-100k + type: chat_template + split: train[:20%] +dataset_prepared_path: last_run_prepared +val_set_size: 0.02 +output_dir: ./outputs/out + +sequence_len: 4096 +sample_packing: true +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: paged_adamw_8bit +lr_scheduler: cosine +learning_rate: 2e-5 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +early_stopping_patience: +resume_from_checkpoint: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 100 +evals_per_epoch: 2 +eval_table_size: +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: + - full_shard + - auto_wrap +fsdp_config: + fsdp_limit_all_gathers: true + fsdp_sync_module_states: true + fsdp_offload_params: true + fsdp_use_orig_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD + fsdp_backward_prefetch: BACKWARD_PRE +special_tokens: + pad_token: <|finetune_right_pad_id|> + eos_token: <|eot_id|> diff --git a/examples/llama-3/fft-8b.yaml b/examples/llama-3/fft-8b.yaml index 908ef6e035..335902aac7 100644 --- a/examples/llama-3/fft-8b.yaml +++ b/examples/llama-3/fft-8b.yaml @@ -1,6 +1,4 @@ -base_model: NousResearch/Meta-Llama-3-8B -model_type: LlamaForCausalLM -tokenizer_type: AutoTokenizer +base_model: NousResearch/Meta-Llama-3.1-8B load_in_8bit: false load_in_4bit: false From 810ecd4e81a93a8cd8d740c4e59f12fa05424f69 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 23 Aug 2024 14:34:03 -0400 Subject: [PATCH 20/89] add liger to readme (#1865) * add liger to readme * updates from PR feedback --- README.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/README.md b/README.md index 8c70da015d..af604fad50 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,7 @@ Features: - [FSDP + QLoRA](#fsdp--qlora) - [Weights \& Biases Logging](#weights--biases-logging) - [Special Tokens](#special-tokens) + - [Liger Kernel](#liger-kernel) - [Inference Playground](#inference-playground) - [Merge LORA to base](#merge-lora-to-base) - [Common Errors 🧰](#common-errors-) @@ -530,6 +531,25 @@ tokens: # these are delimiters When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary. +##### Liger Kernel + +Liger Kernel: Efficient Triton Kernels for LLM Training + +https://github.com/linkedin/Liger-Kernel + +Liger (LinkedIn GPU Efficient Runtime) Kernel is a collection of Triton kernels designed specifically for LLM training. +It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%. The Liger Kernel +composes well and is compatible with both FSDP and Deepspeed. + +```yaml +plugins: + - axolotl.integrations.liger.LigerPlugin +liger_rope: true +liger_rms_norm: true +liger_swiglu: true +liger_fused_linear_cross_entropy: true +``` + ### Inference Playground Axolotl allows you to load your model in an interactive terminal playground for quick experimentation. From 77a4b9cda21deabf97515fb04788b4eec7ed7783 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 23 Aug 2024 17:00:01 -0400 Subject: [PATCH 21/89] change up import to prevent AttributeError (#1863) * change up import to prevent AttributeError * tweak patching check for updated upstream --- src/axolotl/monkeypatch/llama_patch_multipack.py | 12 ++++++------ src/axolotl/monkeypatch/unsloth_.py | 6 ++---- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py index 540c5577a0..cfd525367e 100644 --- a/src/axolotl/monkeypatch/llama_patch_multipack.py +++ b/src/axolotl/monkeypatch/llama_patch_multipack.py @@ -9,18 +9,18 @@ def hijack_llama_prepare_4d_mask(): - import transformers.modeling_attn_mask_utils - import transformers.models.llama.modeling_llama + from transformers import modeling_attn_mask_utils + from transformers.models.llama import modeling_llama - transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access + modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access patched_prepare_4d_causal_attention_mask_for_sdpa ) - transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access + modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access patched_prepare_4d_causal_attention_mask_for_sdpa ) - transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access + modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access patched_prepare_4d_causal_attention_mask ) - transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access + modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access patched_prepare_4d_causal_attention_mask ) diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py index 5b1f0061de..3d42ad17f1 100644 --- a/src/axolotl/monkeypatch/unsloth_.py +++ b/src/axolotl/monkeypatch/unsloth_.py @@ -16,8 +16,7 @@ LOG = get_logger("axolotl.monkeypatch.unsloth") -ORIGINAL_CEL_CODE = """ if labels is not None: - # Shift so that tokens < n predict n +ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens @@ -29,8 +28,7 @@ loss = loss_fct(shift_logits, shift_labels) """ -PATCHED_CEL_CODE = """ if labels is not None: - shift_logits = logits[..., :-1, :].contiguous() +PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = fast_cross_entropy_loss( logits = shift_logits, From 22f4eafa557bc5009877443c601e40a762832c2b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 23 Aug 2024 20:23:08 -0400 Subject: [PATCH 22/89] simplify logic (#1856) --- src/axolotl/utils/models.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 6261ce20fe..e183301991 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -589,19 +589,12 @@ def load_model( # sample packing uses custom FA2 patch if cfg.flash_attention: - if not cfg.sample_packing: - if cfg.s2_attention: - pass - # most other models support flash attention, we can define exceptions as they come up - model_kwargs["attn_implementation"] = "flash_attention_2" - model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) - else: - model_kwargs["attn_implementation"] = "flash_attention_2" - model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) + if not cfg.sample_packing and cfg.s2_attention: + pass + model_kwargs["attn_implementation"] = "flash_attention_2" + model_config._attn_implementation = ( # pylint: disable=protected-access + "flash_attention_2" + ) elif cfg.sdp_attention: model_kwargs["attn_implementation"] = "sdpa" model_config._attn_implementation = "sdpa" # pylint: disable=protected-access From f245964f22f5bceeed27cb4e0374f1c56c6d63d5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 25 Aug 2024 12:31:40 -0400 Subject: [PATCH 23/89] better handling of llama-3 tool rolw (#1782) --- src/axolotl/prompters.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 13ff450f8a..18b73e725e 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -352,9 +352,12 @@ def _build_result(self, source): "Please help us by creating an Issue to add support for this conversation type." ) - role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format( - ROLE=from_role - ) + if self._conversation.name in ["llama3"]: + role = from_role + else: + role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format( + ROLE=from_role + ) if len(conv.messages) > 0 and ((role == conv.messages[-1][0])): if ( From 8e29bdefdd32757dd952f9a295e961c3f1c70ba2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 25 Aug 2024 17:54:02 -0400 Subject: [PATCH 24/89] Spectrum plugin (#1866) --- src/axolotl/integrations/spectrum/LICENSE | 202 ++++++++++++++++++ src/axolotl/integrations/spectrum/README.md | 21 ++ src/axolotl/integrations/spectrum/__init__.py | 102 +++++++++ src/axolotl/integrations/spectrum/args.py | 29 +++ 4 files changed, 354 insertions(+) create mode 100644 src/axolotl/integrations/spectrum/LICENSE create mode 100644 src/axolotl/integrations/spectrum/README.md create mode 100644 src/axolotl/integrations/spectrum/__init__.py create mode 100644 src/axolotl/integrations/spectrum/args.py diff --git a/src/axolotl/integrations/spectrum/LICENSE b/src/axolotl/integrations/spectrum/LICENSE new file mode 100644 index 0000000000..d645695673 --- /dev/null +++ b/src/axolotl/integrations/spectrum/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/src/axolotl/integrations/spectrum/README.md b/src/axolotl/integrations/spectrum/README.md new file mode 100644 index 0000000000..de17db5127 --- /dev/null +++ b/src/axolotl/integrations/spectrum/README.md @@ -0,0 +1,21 @@ +## Spectrum: Targeted Training on Signal to Noise Ratio + +by Eric Hartford, Lucas Atkins, et al. + +This plugin contains code to freeze the bottom fraction of modules in a model, based on the Signal-to-Noise Ratio (SNR). + +### Overview + +Spectrum is a tool for scanning and evaluating the Signal-to-Noise Ratio (SNR) of layers in large language models. +By identifying the top n% of layers with the highest SNR, you can optimize training efficiency. + +### Usage + +```yaml +plugins: + - axolotl.integrations.spectrum.SpectrumPlugin + +spectrum_top_fraction: 0.5 +# Optional if using a pre-scanned model as your base_model. Useful if using a model mirror +spectrum_model_name: meta-llama/Meta-Llama-3.1-8B +``` diff --git a/src/axolotl/integrations/spectrum/__init__.py b/src/axolotl/integrations/spectrum/__init__.py new file mode 100644 index 0000000000..6059e7951c --- /dev/null +++ b/src/axolotl/integrations/spectrum/__init__.py @@ -0,0 +1,102 @@ +# Copyright 2024 Axolotl AI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Spectrum Plugin to automatically generate unfrozen parameters based on SNR data. +""" + +import json +import logging + +import requests + +from axolotl.integrations.base import BasePlugin + +from .args import SpectrumArgs # pylint: disable=unused-import. # noqa: F401 + + +def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5): + unfrozen_parameters = {} + for layer_name, info in snr_data.items(): + layer_type = info["type"] + if layer_type not in unfrozen_parameters: + unfrozen_parameters[layer_type] = [] + unfrozen_parameters[layer_type].append((layer_name, info["snr"])) + top_layers_by_type = {} + for layer_type, layers in unfrozen_parameters.items(): + layers_sorted = sorted(layers, key=lambda x: x[1], reverse=True) + num_top_layers = int(len(layers) * top_fraction) + top_layers_by_type[layer_type] = [ + layer[0] for layer in layers_sorted[:num_top_layers] + ] + unfrozen_parameters = [ + "^lm_head.weight$", + "^model.embed_tokens.weight$", + ] + for layer_type, layer_names in top_layers_by_type.items(): + for layer_name in layer_names: + unfrozen_parameters.append(layer_name) + return unfrozen_parameters + + +class SpectrumPlugin(BasePlugin): + """ + Spectrum Plugin to automatically generate unfrozen parameters based on SNR data. + """ + + base_url = "https://raw.githubusercontent.com/cognitivecomputations/spectrum/main/model_snr_results/" + base_path = "./model_snr_results/" + snr_file_template = "snr_results_{model_name_slug}.json" + + def get_input_args(self): + return "axolotl.integrations.spectrum.SpectrumArgs" + + def pre_model_load(self, cfg): + if cfg.get("spectrum_model_name"): + model_name = cfg["spectrum_model_name"] + else: + model_name = cfg["base_model"] + top_fraction = cfg.get("spectrum_top_fraction", 50) + model_slug = model_name.replace("/", "-").replace("_", "-") + snr_url = self.base_url + self.snr_file_template.format( + model_name_slug=model_slug + ) + snr_path = self.base_path + self.snr_file_template.format( + model_name_slug=model_slug + ) + # first check if the files exist locally and read the json + snr_data = None + try: + with open(snr_path, "r", encoding="utf-8") as fin: + snr_data = json.load(fin) + except FileNotFoundError: + pass + except Exception as exc: # pylint: disable=broad-exception-caught + logging.warning(f"Failed to read SNR data from {snr_path}: {exc}") + + if not snr_data: + try: + snr_data = requests.get(snr_url, timeout=60).json() + except requests.exceptions.RequestException as exc: + logging.warning(f"Failed to fetch SNR data from {snr_url}: {exc}") + return + # also catch json parsing errors + except json.JSONDecodeError as exc: + logging.warning(f"Failed to parse SNR data from {snr_url}: {exc}") + return + + unfrozen_parameters = _generate_unfrozen_params_yaml( + snr_data, top_fraction=top_fraction + ) + cfg["unfrozen_parameters"] = unfrozen_parameters diff --git a/src/axolotl/integrations/spectrum/args.py b/src/axolotl/integrations/spectrum/args.py new file mode 100644 index 0000000000..03426d8413 --- /dev/null +++ b/src/axolotl/integrations/spectrum/args.py @@ -0,0 +1,29 @@ +# Copyright 2024 Axolotl AI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Module for handling Spectrum input arguments. +""" +from typing import Optional + +from pydantic import BaseModel + + +class SpectrumArgs(BaseModel): + """ + Input args for Spectrum. + """ + + spectrum_top_fraction: Optional[float] = 0.5 + spectrum_model_name: Optional[str] = None From 6819c12cee9248defe5bb5d6690aa4bfc03e5351 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 26 Aug 2024 12:00:36 -0400 Subject: [PATCH 25/89] update specturm authors (#1869) --- src/axolotl/integrations/spectrum/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/integrations/spectrum/README.md b/src/axolotl/integrations/spectrum/README.md index de17db5127..192918060e 100644 --- a/src/axolotl/integrations/spectrum/README.md +++ b/src/axolotl/integrations/spectrum/README.md @@ -1,6 +1,6 @@ ## Spectrum: Targeted Training on Signal to Noise Ratio -by Eric Hartford, Lucas Atkins, et al. +by Eric Hartford, Lucas Atkins, Fernando Fernandes, David Golchinfar This plugin contains code to freeze the bottom fraction of modules in a model, based on the Signal-to-Noise Ratio (SNR). From 2dac1edf7225cc75bd781d78a7d0cca33bd8560f Mon Sep 17 00:00:00 2001 From: Chiwan Park Date: Tue, 27 Aug 2024 01:56:12 +0900 Subject: [PATCH 26/89] Fix `drop_long_seq` bug due to truncation in prompt tokenization strategies when using `chat_template` (#1867) --- src/axolotl/prompt_strategies/chat_template.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 8ae668d7e9..19e36531a5 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -350,7 +350,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): ), "roles": ds_cfg.get("roles"), "drop_system_message": ds_cfg.get("drop_system_message", False), - "max_length": cfg.sequence_len, + # we need to add one for detecting sequences with exceeding the `sequence_len` limit. + "max_length": cfg.sequence_len + 1, } strategy_params = { From 17af1d7081414c32614cbabe324e1197ca9f43a7 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 26 Aug 2024 15:50:26 -0400 Subject: [PATCH 27/89] clear cuda cache to help with memory leak/creep (#1858) * clear cuda cache to help with memory leak/creep * reverse order of gc --- src/axolotl/core/trainer_builder.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 1a073ca047..656ded2559 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -4,6 +4,7 @@ """ import abc +import gc import importlib import importlib.util import logging @@ -15,11 +16,12 @@ from dataclasses import dataclass, field from functools import wraps from pathlib import Path -from typing import Dict, List, Literal, Optional, Type, Union +from typing import Any, Dict, List, Literal, Optional, Type, Union import torch import transformers from datasets import Dataset +from torch import nn from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from transformers import ( @@ -997,6 +999,14 @@ def tokenize_row( res[key] = res[key][1:] return res + def training_step( + self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]] + ) -> torch.Tensor: + loss: torch.Tensor = super().training_step(model, inputs) + gc.collect() + torch.cuda.empty_cache() + return loss + class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): """ From f6362d2a05cf87e9c1615d503751e059f1b110d0 Mon Sep 17 00:00:00 2001 From: Chiwan Park Date: Wed, 28 Aug 2024 02:03:16 +0900 Subject: [PATCH 28/89] Add Liger Kernal support for Qwen2 (#1871) --- src/axolotl/integrations/liger/__init__.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index d4c1ad9a4d..bf4c83af4f 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -23,6 +23,7 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.geglu import LigerGEGLUMLP from liger_kernel.transformers.model.llama import lce_forward +from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.transformers.swiglu import LigerSwiGLUMLP @@ -102,3 +103,17 @@ def pre_model_load(self, cfg): modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss if cfg.liger_fused_linear_cross_entropy: modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward + + elif cfg.model_config_type == "qwen2": + from transformers.models.qwen2 import modeling_qwen2 + + if cfg.liger_rope: + modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb + if cfg.liger_rms_norm: + modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm + if cfg.liger_swiglu: + modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP + if cfg.liger_cross_entropy: + modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss + if cfg.liger_fused_linear_cross_entropy: + modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward From 1e4366070179f238a40b7ef8356be744350c1a38 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 27 Aug 2024 13:39:24 -0400 Subject: [PATCH 29/89] Sample pack trust remote code v2 (#1873) * fix the multipack patch for remote code models * add deepseek v2 lite example w fsdp --- examples/deepseek-v2/fft-fsdp-16b.yaml | 67 ++++++++++++++++++++++++++ src/axolotl/monkeypatch/multipack.py | 2 + src/axolotl/monkeypatch/utils.py | 2 - 3 files changed, 69 insertions(+), 2 deletions(-) create mode 100644 examples/deepseek-v2/fft-fsdp-16b.yaml diff --git a/examples/deepseek-v2/fft-fsdp-16b.yaml b/examples/deepseek-v2/fft-fsdp-16b.yaml new file mode 100644 index 0000000000..b55646df7f --- /dev/null +++ b/examples/deepseek-v2/fft-fsdp-16b.yaml @@ -0,0 +1,67 @@ +base_model: deepseek-ai/DeepSeek-V2-Lite +trust_remote_code: true + +load_in_8bit: false +load_in_4bit: false +strict: false + +datasets: + - path: tatsu-lab/alpaca + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +sequence_len: 2048 +sample_packing: true +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 8 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_torch +lr_scheduler: cosine +learning_rate: 2e-5 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +early_stopping_patience: +resume_from_checkpoint: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 100 +evals_per_epoch: 2 +eval_table_size: +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +special_tokens: +fsdp: + - full_shard + - auto_wrap +fsdp_config: + fsdp_limit_all_gathers: true + fsdp_sync_module_states: true + fsdp_offload_params: true + fsdp_use_orig_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 44fc4cb473..529c42a8f5 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -94,3 +94,5 @@ def patch_remote(model_name, config_name, modeling_name): module_name = model_config.__class__.__module__.replace(config_name, modeling_name) modeling_arch = importlib.import_module(module_name) modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access + # workaround to make the patch stick + modeling_arch._axolotl_multipack_patch = True # pylint: disable=protected-access diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py index e43c58650a..f29f21be77 100644 --- a/src/axolotl/monkeypatch/utils.py +++ b/src/axolotl/monkeypatch/utils.py @@ -17,11 +17,9 @@ def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor: max_num = int(torch.max(attention_mask).item()) batch_size, _ = attention_mask.shape counts = torch.zeros((batch_size, max_num), dtype=torch.int32) - for i in range(1, max_num + 1): mask = attention_mask == i counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32) - result = counts.flatten() nonzero_indices = torch.nonzero(result).squeeze(-1) return result[nonzero_indices] From 159b8b9a74af6745a2495138c1cdcaf0cc666ab8 Mon Sep 17 00:00:00 2001 From: Aman Gupta Karmani Date: Tue, 27 Aug 2024 17:22:26 -0700 Subject: [PATCH 30/89] monkey-patch transformers to simplify monkey-patching modeling code (#1877) * monkey-patch transformers so that monkey-patched modeling code doesnt get overwritten * unnecessary now * add comment --- src/axolotl/monkeypatch/multipack.py | 2 - .../transformers_dynamic_module_utils.py | 51 +++++++++++++++++++ src/axolotl/utils/models.py | 5 ++ 3 files changed, 56 insertions(+), 2 deletions(-) create mode 100644 src/axolotl/monkeypatch/transformers_dynamic_module_utils.py diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 529c42a8f5..44fc4cb473 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -94,5 +94,3 @@ def patch_remote(model_name, config_name, modeling_name): module_name = model_config.__class__.__module__.replace(config_name, modeling_name) modeling_arch = importlib.import_module(module_name) modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access - # workaround to make the patch stick - modeling_arch._axolotl_multipack_patch = True # pylint: disable=protected-access diff --git a/src/axolotl/monkeypatch/transformers_dynamic_module_utils.py b/src/axolotl/monkeypatch/transformers_dynamic_module_utils.py new file mode 100644 index 0000000000..dfc3e29c5a --- /dev/null +++ b/src/axolotl/monkeypatch/transformers_dynamic_module_utils.py @@ -0,0 +1,51 @@ +"""Patch transformers.dynamic_module_utils.get_class_in_module to avoid reloading models from disk""" + +import importlib +import os +import sys +import typing +from pathlib import Path + +from transformers.file_utils import HF_MODULES_CACHE + + +def _patched_get_class_in_module( + class_name: str, module_path: typing.Union[str, os.PathLike] +) -> typing.Type: + """ + Import a module on the cache directory for modules and extract a class from it. + + Args: + class_name (`str`): The name of the class to import. + module_path (`str` or `os.PathLike`): The path to the module to import. + + Returns: + `typing.Type`: The class looked for. + """ + name = os.path.normpath(module_path) + if name.endswith(".py"): + name = name[:-3] + name = name.replace(os.path.sep, ".") + module_spec = importlib.util.spec_from_file_location( + name, location=Path(HF_MODULES_CACHE) / module_path + ) + module = sys.modules.get(name) + if module is None: + module = importlib.util.module_from_spec(module_spec) + # insert it into sys.modules before any loading begins + sys.modules[name] = module + # load in initial case only + module_spec.loader.exec_module(module) + return getattr(module, class_name) + + +def patch_transformers_dynamic_module_utils(): + """ + Recently, transformers started reloading modeling code from disk for models marked trust_remote_code=True. + This causes monkey-patches for multipack and liger to be removed. + We replace the original function with a version that does not reload the module from disk. + See https://github.com/huggingface/transformers/pull/30370#pullrequestreview-2264361581 + """ + import transformers + + transformers.dynamic_module_utils.get_class_in_module = _patched_get_class_in_module diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index e183301991..e0526fb048 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -43,6 +43,9 @@ SUPPORTED_MULTIPACK_MODEL_TYPES, patch_for_multipack, ) +from axolotl.monkeypatch.transformers_dynamic_module_utils import ( + patch_transformers_dynamic_module_utils, +) from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.chat_templates import chat_templates @@ -54,6 +57,8 @@ LOG = logging.getLogger("axolotl") +patch_transformers_dynamic_module_utils() + # copied from accelerator.FullyShardedDataParallelPlugin def get_module_class_from_name(module, name): From c1a61ae23c8967952e445fef0185e13d72d28dd2 Mon Sep 17 00:00:00 2001 From: Aman Gupta Karmani Date: Tue, 27 Aug 2024 20:08:26 -0700 Subject: [PATCH 31/89] fix liger plugin load issues (#1876) --- src/axolotl/integrations/liger/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index bf4c83af4f..f78083300d 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -22,8 +22,7 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.geglu import LigerGEGLUMLP -from liger_kernel.transformers.model.llama import lce_forward -from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward +from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.transformers.swiglu import LigerSwiGLUMLP @@ -54,7 +53,7 @@ def pre_model_load(self, cfg): if cfg.liger_cross_entropy: modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss elif cfg.liger_fused_linear_cross_entropy: - modeling_llama.LlamaForCausalLM.forward = lce_forward + modeling_llama.LlamaForCausalLM.forward = llama_lce_forward elif cfg.model_config_type == "mistral": from transformers.models.mistral import modeling_mistral @@ -105,6 +104,9 @@ def pre_model_load(self, cfg): modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward elif cfg.model_config_type == "qwen2": + from liger_kernel.transformers.model.qwen2 import ( + lce_forward as qwen2_lce_forward, + ) from transformers.models.qwen2 import modeling_qwen2 if cfg.liger_rope: From 7037e3c836960b315b469ab5659c1695b2fe0583 Mon Sep 17 00:00:00 2001 From: Aman Gupta Karmani Date: Tue, 27 Aug 2024 20:52:40 -0700 Subject: [PATCH 32/89] deepseekv2 liger support (#1878) * deepseekv2 liger support * add comment * add missing impl --- src/axolotl/integrations/liger/__init__.py | 26 ++++ .../integrations/liger/models/deepseekv2.py | 127 ++++++++++++++++++ 2 files changed, 153 insertions(+) create mode 100644 src/axolotl/integrations/liger/models/deepseekv2.py diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index f78083300d..2a3e95163b 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -19,6 +19,7 @@ It is designed to be performant, correct, and light-weight. """ import logging +import sys from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.geglu import LigerGEGLUMLP @@ -119,3 +120,28 @@ def pre_model_load(self, cfg): modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss if cfg.liger_fused_linear_cross_entropy: modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward + + elif cfg.model_config_type == "deepseek_v2": + from accelerate import init_empty_weights + from transformers import AutoModelForCausalLM + + with init_empty_weights(): + model = AutoModelForCausalLM.from_pretrained( + cfg.base_model, trust_remote_code=cfg.trust_remote_code or False + ) + modeling_mod = sys.modules[model.__class__.__module__] + + from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward + + if cfg.liger_rope: + # The DeepseekV2 version of RoPE is different than upstream LLaMA. + # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528 + logging.warning("Fused liger_rope is not supported for DeepseekV2.") + if cfg.liger_rms_norm: + modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm + if cfg.liger_swiglu: + modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward + if cfg.liger_cross_entropy: + modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss + if cfg.liger_fused_linear_cross_entropy: + modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward diff --git a/src/axolotl/integrations/liger/models/deepseekv2.py b/src/axolotl/integrations/liger/models/deepseekv2.py new file mode 100644 index 0000000000..79fb274360 --- /dev/null +++ b/src/axolotl/integrations/liger/models/deepseekv2.py @@ -0,0 +1,127 @@ +""" +DeepseekV2 model with LigerFusedLinearCrossEntropyLoss +""" +# pylint: disable=duplicate-code + +from typing import List, Optional, Tuple, Union + +import torch +from liger_kernel.transformers.fused_linear_cross_entropy import ( + LigerFusedLinearCrossEntropyLoss, +) +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import CausalLMOutputWithPast + + +# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) +# @replace_return_docstrings( +# output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +# ) +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM + + >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + loss = None + logits = None + + if self.training: + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + lce = LigerFusedLinearCrossEntropyLoss() + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) From e3a38450ded395feb8ef1bc227d09b8b476d18f9 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Thu, 29 Aug 2024 05:19:18 -0700 Subject: [PATCH 33/89] Add liger kernel to features (#1881) [skip ci] --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index af604fad50..c84f1cb8c9 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Features: - Supports fullfinetune, lora, qlora, relora, and gptq - Customize configurations using a simple yaml file or CLI overwrite - Load different dataset formats, use custom formats, or bring your own tokenized datasets -- Integrated with xformer, flash attention, rope scaling, and multipacking +- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking - Works with single GPU or multiple GPUs via FSDP or Deepspeed - Easily run with Docker locally or on the cloud - Log results and optionally checkpoints to wandb or mlflow From ce33e1ed839cc15b9351d3567ab44076cd4809c6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 30 Aug 2024 17:51:18 -0400 Subject: [PATCH 34/89] pin liger-kernel to latest 0.2.1 (#1882) [skip ci] --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f5fb547a26..b8d0a388b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,7 +34,7 @@ tensorboard python-dotenv==1.0.1 autoawq>=0.2.5 triton>=2.3.0 -liger-kernel +liger-kernel==0.2.1 mamba-ssm==1.2.0.post1 From 15408d0f09062463a50d39cb0ff356fa1590a2b3 Mon Sep 17 00:00:00 2001 From: DocShotgun <126566557+DocShotgun@users.noreply.github.com> Date: Sat, 31 Aug 2024 18:59:48 -0700 Subject: [PATCH 35/89] Update supported models for Liger Kernel (#1875) * Update supported models for Liger Kernel Add Mistral LCE, Gemma LCE, Gemma 2 without LCE (softcapping is not yet implemented for Gemma in Liger Kernel LCE forward), Phi3 without LCE * move import to their appropriate conditions * Integrate Phi3 LCE support https://github.com/linkedin/Liger-Kernel/pull/103/ --------- Co-authored-by: Wing Lian --- src/axolotl/integrations/liger/__init__.py | 51 +++++++++++++++++++--- 1 file changed, 44 insertions(+), 7 deletions(-) diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index 2a3e95163b..d58349932b 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -23,7 +23,6 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.geglu import LigerGEGLUMLP -from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.transformers.swiglu import LigerSwiGLUMLP @@ -43,6 +42,9 @@ def get_input_args(self): def pre_model_load(self, cfg): if cfg.model_config_type == "llama": + from liger_kernel.transformers.model.llama import ( + lce_forward as llama_lce_forward, + ) from transformers.models.llama import modeling_llama if cfg.liger_rope: @@ -57,6 +59,9 @@ def pre_model_load(self, cfg): modeling_llama.LlamaForCausalLM.forward = llama_lce_forward elif cfg.model_config_type == "mistral": + from liger_kernel.transformers.model.mistral import ( + lce_forward as mistral_lce_forward, + ) from transformers.models.mistral import modeling_mistral if cfg.liger_rope: @@ -68,11 +73,12 @@ def pre_model_load(self, cfg): if cfg.liger_cross_entropy: modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss if cfg.liger_fused_linear_cross_entropy: - logging.warning( - "Fused linear cross entropy is not supported for Mistral." - ) + modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward elif cfg.model_config_type == "gemma": + from liger_kernel.transformers.model.gemma import ( + lce_forward as gemma_lce_forward, + ) from transformers.models.gemma import modeling_gemma if cfg.liger_rope: @@ -84,9 +90,7 @@ def pre_model_load(self, cfg): if cfg.liger_cross_entropy: modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss if cfg.liger_fused_linear_cross_entropy: - logging.warning( - "Fused linear cross entropy is not supported for Gemma." - ) + modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward elif cfg.model_config_type == "jamba": from transformers.models.jamba import modeling_jamba @@ -145,3 +149,36 @@ def pre_model_load(self, cfg): modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss if cfg.liger_fused_linear_cross_entropy: modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward + + elif cfg.model_config_type == "gemma2": + from transformers.models.gemma2 import modeling_gemma2 + + if cfg.liger_rope: + modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb + if cfg.liger_rms_norm: + modeling_gemma2.Gemma2RMSNorm = LigerRMSNorm + if cfg.liger_swiglu: + modeling_gemma2.Gemma2MLP = LigerGEGLUMLP + if cfg.liger_cross_entropy: + modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss + if cfg.liger_fused_linear_cross_entropy: + logging.warning( + "Fused linear cross entropy is not supported for Gemma 2." + ) + + elif cfg.model_config_type == "phi3": + from liger_kernel.transformers.model.phi3 import ( + lce_forward as phi3_lce_forward, + ) + from transformers.models.phi3 import modeling_phi3 + + if cfg.liger_rope: + modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb + if cfg.liger_rms_norm: + modeling_phi3.Phi3RMSNorm = LigerRMSNorm + if cfg.liger_swiglu: + modeling_phi3.Phi3MLP = LigerSwiGLUMLP + if cfg.liger_cross_entropy: + modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss + if cfg.liger_fused_linear_cross_entropy: + modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward From 3c6b9eda2ecffb0204eb1c29635f75c0115317b3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 31 Aug 2024 22:49:35 -0400 Subject: [PATCH 36/89] run pytests with varied pytorch versions too (#1883) --- .github/workflows/tests-nightly.yml | 5 +++++ .github/workflows/tests.yml | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml index 6b35698cbf..30ed397cef 100644 --- a/.github/workflows/tests-nightly.yml +++ b/.github/workflows/tests-nightly.yml @@ -25,6 +25,7 @@ jobs: fail-fast: false matrix: python_version: ["3.10", "3.11"] + pytorch_version: ["2.3.1", "2.4.0"] timeout-minutes: 20 steps: @@ -37,6 +38,10 @@ jobs: python-version: ${{ matrix.python_version }} cache: 'pip' # caching pip dependencies + - name: Install PyTorch + run: | + pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu + - name: Update requirements.txt run: | sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 74b4bcfbdb..c104e92c27 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,6 +36,7 @@ jobs: fail-fast: false matrix: python_version: ["3.10", "3.11"] + pytorch_version: ["2.3.1", "2.4.0"] timeout-minutes: 20 steps: @@ -48,6 +49,10 @@ jobs: python-version: ${{ matrix.python_version }} cache: 'pip' # caching pip dependencies + - name: Install PyTorch + run: | + pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu + - name: Install dependencies run: | pip3 install --upgrade pip From bdab3ec587f9e85f5684cf57302a1bf93e78de1e Mon Sep 17 00:00:00 2001 From: Chiwan Park Date: Mon, 2 Sep 2024 07:34:24 +0900 Subject: [PATCH 37/89] Fix RMSNorm monkey patch for Gemma models (#1886) --- src/axolotl/integrations/liger/__init__.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index d58349932b..2047f3815d 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -20,6 +20,7 @@ """ import logging import sys +from functools import partial from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.geglu import LigerGEGLUMLP @@ -84,7 +85,9 @@ def pre_model_load(self, cfg): if cfg.liger_rope: modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb if cfg.liger_rms_norm: - modeling_gemma.GemmaRMSNorm = LigerRMSNorm + modeling_gemma.GemmaRMSNorm = partial( + LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma" + ) if cfg.liger_swiglu: modeling_gemma.GemmaMLP = LigerGEGLUMLP if cfg.liger_cross_entropy: @@ -156,7 +159,9 @@ def pre_model_load(self, cfg): if cfg.liger_rope: modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb if cfg.liger_rms_norm: - modeling_gemma2.Gemma2RMSNorm = LigerRMSNorm + modeling_gemma2.Gemma2RMSNorm = partial( + LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma" + ) if cfg.liger_swiglu: modeling_gemma2.Gemma2MLP = LigerGEGLUMLP if cfg.liger_cross_entropy: From 0aeb277456f0ed79ab46191a12998fccc257d414 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 1 Sep 2024 19:29:37 -0400 Subject: [PATCH 38/89] add e2e smoke tests for llama liger integration (#1884) * add e2e smoke tests for llama liger integration * fix import * don't use __main__ for test * consolidate line --- cicd/cicd.sh | 4 +- tests/e2e/integrations/__init__.py | 0 tests/e2e/integrations/liger.py | 110 +++++++++++++++++++++++++++++ 3 files changed, 112 insertions(+), 2 deletions(-) create mode 100644 tests/e2e/integrations/__init__.py create mode 100644 tests/e2e/integrations/liger.py diff --git a/cicd/cicd.sh b/cicd/cicd.sh index eceda9b375..104a8f84ab 100755 --- a/cicd/cicd.sh +++ b/cicd/cicd.sh @@ -2,5 +2,5 @@ set -e pytest --ignore=tests/e2e/ /workspace/axolotl/tests/ -pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ -pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ /workspace/axolotl/tests/e2e/ +pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/ +pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ diff --git a/tests/e2e/integrations/__init__.py b/tests/e2e/integrations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/e2e/integrations/liger.py b/tests/e2e/integrations/liger.py new file mode 100644 index 0000000000..4497cebe32 --- /dev/null +++ b/tests/e2e/integrations/liger.py @@ -0,0 +1,110 @@ +""" +Simple end-to-end test for Liger integration +""" + +import unittest +from pathlib import Path + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from ..utils import with_temp_dir + + +class LigerIntegrationTestCase(unittest.TestCase): + """ + e2e tests for liger integration with Axolotl + """ + + @with_temp_dir + def test_llama_wo_flce(self, temp_dir): + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "plugins": [ + "axolotl.integrations.liger.LigerPlugin", + ], + "liger_rope": True, + "liger_rms_norm": True, + "liger_swiglu": True, + "liger_cross_entropy": True, + "liger_fused_linear_cross_entropy": False, + "sequence_len": 1024, + "val_set_size": 0.1, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "model.safetensors").exists() + + @with_temp_dir + def test_llama_w_flce(self, temp_dir): + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "plugins": [ + "axolotl.integrations.liger.LigerPlugin", + ], + "liger_rope": True, + "liger_rms_norm": True, + "liger_swiglu": True, + "liger_cross_entropy": False, + "liger_fused_linear_cross_entropy": True, + "sequence_len": 1024, + "val_set_size": 0.1, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "model.safetensors").exists() From 4e5400c732c6b8baf2d0f7a700ce773b1501b1fc Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 3 Sep 2024 20:02:44 -0400 Subject: [PATCH 39/89] support for auto_find_batch_size when packing (#1885) * support for auto_find_batch_size when packing * make sure to return data from validation * make sure to return data from validation * actually expose multipack_real_batches in the config * calculate gathered efficiency in sampler * tweak to fix auto find and use actual sampler len for multipack * uncomment * use args for bsz when not available from auto find --- src/axolotl/core/trainer_builder.py | 15 ++++--- .../config/models/input/v0_4_1/__init__.py | 3 ++ src/axolotl/utils/samplers/multipack.py | 40 +++++++++++++++++-- src/axolotl/utils/trainer.py | 2 +- 4 files changed, 50 insertions(+), 10 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 656ded2559..f4cd257838 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -506,9 +506,10 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: batch_max_len = self.args.max_seq_length else: batch_size = 1 - batch_max_len = ( - self.args.per_device_train_batch_size * self.args.max_seq_length + train_batch_size = ( + self.state.train_batch_size or self.args.per_device_train_batch_size ) + batch_max_len = train_batch_size * self.args.max_seq_length return MultipackBatchSampler( RandomSampler(self.train_dataset), lengths=get_dataset_lengths(self.train_dataset), @@ -1379,6 +1380,10 @@ def build(self, total_num_steps): training_arguments_kwargs[ "per_device_eval_batch_size" ] = self.cfg.eval_batch_size + if self.cfg.auto_find_batch_size is not None: + training_arguments_kwargs[ + "auto_find_batch_size" + ] = self.cfg.auto_find_batch_size training_arguments_kwargs[ "gradient_accumulation_steps" ] = self.cfg.gradient_accumulation_steps @@ -1461,9 +1466,9 @@ def build(self, total_num_steps): ) training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing) - training_arguments_kwargs[ - "multipack_real_batches" - ] = not self.cfg.flash_attention + training_arguments_kwargs["multipack_real_batches"] = ( + not self.cfg.flash_attention or self.cfg.multipack_real_batches + ) training_arguments_kwargs["eval_sample_packing"] = bool( self.cfg.eval_sample_packing ) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 65a2c5409a..9044047cce 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -355,6 +355,8 @@ class HyperparametersConfig(BaseModel): }, ) + auto_find_batch_size: Optional[bool] = None + train_on_inputs: Optional[bool] = False group_by_length: Optional[bool] = None @@ -592,6 +594,7 @@ class Config: eval_sample_packing: Optional[bool] = None pad_to_sequence_len: Optional[bool] = None curriculum_sampling: Optional[bool] = None + multipack_real_batches: Optional[bool] = None # for PoSE context length extension use_pose: Optional[bool] = None diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 957ca57464..205c2894d1 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -11,6 +11,8 @@ import numpy as np from torch.utils.data import BatchSampler, Sampler +from axolotl.utils.distributed import reduce_and_broadcast + LOG = logging.getLogger("axolotl.utils.samplers.multipack") @@ -174,16 +176,46 @@ def num_batches(self): def efficiency(self): return self.eff_total_used / self.eff_total_slots + def gather_efficiency(self): + def calc_sample_packing_eff_est(estimates: List[float]): + LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}") + return math.floor(0.997 * max(estimates)) + + sample_packing_actual_eff_all = reduce_and_broadcast( + lambda: self.efficiency(), # pylint: disable=unnecessary-lambda + calc_sample_packing_eff_est, + ) + sample_packing_eff_est = ( + math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0 + ) + return sample_packing_eff_est + + def gather_len_batches(self, num): + def calc_min_len(estimates: list[(int, float)]): + LOG.info(f"gather_len_batches: {repr(estimates)}") + return math.floor(0.998 * min(estimates)) + + min_len_batches = reduce_and_broadcast( + lambda: num, + calc_min_len, + ) + return min_len_batches + def __len__(self): - self.num_batches() - return self._len_est() + len_batches = self.num_batches() + return self.gather_len_batches(len_batches) def _len_est(self): + efficiency = ( + self.packing_efficiency_estimate + if self.packing_efficiency_estimate + else self.gather_efficiency() + ) world_size = int(os.getenv("WORLD_SIZE", "1")) lengths_sum = np.sum(self.lengths) lengths_sum_per_device = lengths_sum // world_size LOG.info( - f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " + f"packing_efficiency_estimate: {efficiency} " f"total_num_tokens per device: {lengths_sum_per_device}" ) @@ -195,7 +227,7 @@ def _len_est(self): * math.floor( 0.99 * lengths_sum_per_device - / self.packing_efficiency_estimate + / efficiency // (self.batch_max_len * self.batch_size) ) - 1 diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index f4e1fc6cb8..1029fff13d 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -357,7 +357,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): main_process_only=True, ) else: - if cfg.flash_attention: + if cfg.flash_attention and not cfg.multipack_real_batches: sampler_batch_size = 1 batch_max_len = cfg.micro_batch_size * cfg.sequence_len else: From dca1fe47d44d69c4558729070a3716b6f79e555c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 4 Sep 2024 11:28:47 -0400 Subject: [PATCH 40/89] fix optimizer + fsdp combination in example (#1893) --- examples/llama-3/fft-8b-liger-fsdp.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llama-3/fft-8b-liger-fsdp.yaml b/examples/llama-3/fft-8b-liger-fsdp.yaml index a64965d207..e84d221f85 100644 --- a/examples/llama-3/fft-8b-liger-fsdp.yaml +++ b/examples/llama-3/fft-8b-liger-fsdp.yaml @@ -31,7 +31,7 @@ wandb_log_model: gradient_accumulation_steps: 4 micro_batch_size: 2 num_epochs: 1 -optimizer: paged_adamw_8bit +optimizer: adamw_torch lr_scheduler: cosine learning_rate: 2e-5 From f18f4268b53be3c9518e2223b1f0aa874d5938c3 Mon Sep 17 00:00:00 2001 From: Tijmen de Haan Date: Thu, 5 Sep 2024 18:33:19 +0900 Subject: [PATCH 41/89] Docs for AMD-based HPC systems (#1891) * Add documentation for installing on AMD-based HPC systems. * Accept suggestion to add note about deepspeed Co-authored-by: NanoCode012 * Update _quarto.yml with amd_hpc doc --------- Co-authored-by: Tijmen de Haan Co-authored-by: NanoCode012 --- _quarto.yml | 1 + docs/amd_hpc.qmd | 108 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+) create mode 100644 docs/amd_hpc.qmd diff --git a/_quarto.yml b/_quarto.yml index 6b2eed971b..acb4872589 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -37,6 +37,7 @@ website: - docs/mac.qmd - docs/multi-node.qmd - docs/unsloth.qmd + - docs/amd_hpc.qmd - section: "Dataset Formats" contents: docs/dataset-formats/* - section: "Reference" diff --git a/docs/amd_hpc.qmd b/docs/amd_hpc.qmd new file mode 100644 index 0000000000..92eadee03a --- /dev/null +++ b/docs/amd_hpc.qmd @@ -0,0 +1,108 @@ +--- +title: Training with AMD GPUs on HPC Systems +description: A comprehensive guide for using Axolotl on distributed systems with AMD GPUs +--- + +This guide provides step-by-step instructions for installing and configuring Axolotl on a High-Performance Computing (HPC) environment equipped with AMD GPUs. + +## Setup + +### 1. Install Python + +We recommend using Miniforge, a minimal conda-based Python distribution: + +```bash +curl -L -O "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh" +bash Miniforge3-$(uname)-$(uname -m).sh +``` + +### 2. Configure Python Environment +Add Python to your PATH and ensure it's available at login: + +```bash +echo 'export PATH=~/miniforge3/bin:$PATH' >> ~/.bashrc +echo 'if [ -f ~/.bashrc ]; then . ~/.bashrc; fi' >> ~/.bash_profile +``` + +### 3. Load AMD GPU Software + +Load the ROCm module: + +```bash +module load rocm/5.7.1 +``` + +Note: The specific module name and version may vary depending on your HPC system. Consult your system documentation for the correct module name. + +### 4. Install PyTorch + +Install PyTorch with ROCm support: + +```bash +pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7 --force-reinstall +``` + +### 5. Install Flash Attention + +Clone and install the Flash Attention repository: + +```bash +git clone --recursive https://github.com/ROCmSoftwarePlatform/flash-attention.git +export GPU_ARCHS="gfx90a" +cd flash-attention +export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])') +patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch +pip install . +``` + +### 6. Install Axolotl + +Clone and install Axolotl: + +```bash +git clone https://github.com/axolotl-ai-cloud/axolotl +cd axolotl +pip install packaging ninja +pip install -e . +``` + +### 7. Apply xformers Workaround + +xformers appears to be incompatible with ROCm. Apply the following workarounds: + - Edit $HOME/packages/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py modifying the code to always return `False` for SwiGLU availability from xformers. + - Edit $HOME/miniforge3/lib/python3.10/site-packages/xformers/ops/swiglu_op.py replacing the "SwiGLU" function with a pass statement. + +### 8. Prepare Job Submission Script + +Create a script for job submission using your HPC's particular software (e.g. Slurm, PBS). Include necessary environment setup and the command to run Axolotl training. If the compute node(s) do(es) not have internet access, it is recommended to include + +```bash +export TRANSFORMERS_OFFLINE=1 +export HF_DATASETS_OFFLINE=1 +``` + +### 9. Download Base Model + +Download a base model using the Hugging Face CLI: + +```bash +huggingface-cli download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B +``` + +### 10. Create Axolotl Configuration + +Create an Axolotl configuration file (YAML format) tailored to your specific training requirements and dataset. Use FSDP for multi-node training. + +Note: Deepspeed did not work at the time of testing. However, if anyone managed to get it working, please let us know. + +### 11. Preprocess Data + +Run preprocessing on the login node: + +```bash +CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess /path/to/your/config.yaml +``` + +### 12. Train + +You are now ready to submit your previously prepared job script. 🚂 \ No newline at end of file From 93b769a9792db5908885537ed42fb7eef80f0f1c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 5 Sep 2024 09:58:21 -0400 Subject: [PATCH 42/89] lint fix and update gha regex (#1899) --- .github/workflows/lint.yml | 2 +- docs/amd_hpc.qmd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 671be4b652..919cfd6545 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -6,7 +6,7 @@ on: - '**.py' - 'requirements.txt' - '.github/workflows/*.yml' - - "*.md" + - "*.[q]md" - "examples/**/*.y[a]?ml" workflow_dispatch: diff --git a/docs/amd_hpc.qmd b/docs/amd_hpc.qmd index 92eadee03a..d1c274e15a 100644 --- a/docs/amd_hpc.qmd +++ b/docs/amd_hpc.qmd @@ -105,4 +105,4 @@ CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess /path/to/your/config.ya ### 12. Train -You are now ready to submit your previously prepared job script. 🚂 \ No newline at end of file +You are now ready to submit your previously prepared job script. 🚂 From ab461d83c4b78df70d310ce45e33ef145796611d Mon Sep 17 00:00:00 2001 From: Alpay Ariyak <98838263+alpayariyak@users.noreply.github.com> Date: Thu, 5 Sep 2024 07:11:31 -0700 Subject: [PATCH 43/89] Fix documentation for pre-tokenized dataset (#1894) It's currently asking to not add BOS and EOS, stating that Axolotl adds them, but this is not true --- docs/dataset-formats/tokenized.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/dataset-formats/tokenized.qmd b/docs/dataset-formats/tokenized.qmd index b2ea003c02..61028cae7f 100644 --- a/docs/dataset-formats/tokenized.qmd +++ b/docs/dataset-formats/tokenized.qmd @@ -7,7 +7,7 @@ order: 5 - Pass an empty `type:` in your axolotl config. - Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels` - To indicate that a token should be ignored during training, set its corresponding label to `-100`. -- Do not add BOS/EOS. Axolotl will add them for you based on the default tokenizer for the model you're using. +- You must add BOS and EOS, and make sure that you are training on EOS by not setting its label to -100. - For pretraining, do not truncate/pad documents to the context window length. - For instruction training, documents must be truncated/padded as desired. From 6e354682e3c1735d3f7fb9e362280c38e922260f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 5 Sep 2024 10:58:50 -0400 Subject: [PATCH 44/89] fix zero3 integration (#1897) * fix zero3 integration * bump transformers and accelerate too --- requirements.txt | 4 ++-- src/axolotl/utils/trainer.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index b8d0a388b3..c61216e63b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 peft==0.12.0 -transformers==4.44.0 +transformers==4.44.2 tokenizers>=0.19.1 bitsandbytes==0.43.3 -accelerate==0.33.0 +accelerate==0.34.0 datasets==2.20.0 deepspeed==0.14.4 pydantic==2.6.3 diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 1029fff13d..89ae4e6970 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -425,7 +425,8 @@ def setup_deepspeed_env(cfg, stage=None): os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage) if stage == 3: os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true" - HfTrainerDeepSpeedConfig(cfg.deepspeed) + # If we don't assign this, it doesn't actually get set in the accelerate weakref + _ = HfTrainerDeepSpeedConfig(cfg.deepspeed) def setup_fsdp_envs(cfg): From 3853ab7ae9220dfbd78cd628e54fde75fb89df97 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 7 Sep 2024 14:39:31 -0400 Subject: [PATCH 45/89] bump accelerate to 0.34.2 (#1901) * bump accelerate * add fixture to predownload the test model * change fixture --- .github/workflows/multi-gpu-e2e.yml | 3 +++ requirements.txt | 2 +- tests/e2e/multigpu/test_llama.py | 7 +++++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml index 91cbaf957e..ab886c67f1 100644 --- a/.github/workflows/multi-gpu-e2e.yml +++ b/.github/workflows/multi-gpu-e2e.yml @@ -1,6 +1,9 @@ name: docker-multigpu-tests-biweekly on: + pull_request: + paths: + - 'tests/e2e/multigpu/*.py' workflow_dispatch: schedule: - cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday diff --git a/requirements.txt b/requirements.txt index c61216e63b..83116af60f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ peft==0.12.0 transformers==4.44.2 tokenizers>=0.19.1 bitsandbytes==0.43.3 -accelerate==0.34.0 +accelerate==0.34.2 datasets==2.20.0 deepspeed==0.14.4 pydantic==2.6.3 diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index 344c57fb85..61bb8ed327 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -10,6 +10,7 @@ import pytest import yaml from accelerate.test_utils import execute_subprocess_async +from huggingface_hub import snapshot_download from axolotl.utils.dict import DictDefault @@ -19,6 +20,12 @@ os.environ["WANDB_DISABLED"] = "true" +@pytest.fixture(scope="session", autouse=True) +def download_model(): + # download the model + snapshot_download("TinyLlama/TinyLlama_v1.1") + + class TestMultiGPULlama(unittest.TestCase): """ Test case for Llama models using LoRA From 5c42f114115cc4e2dd49c1da2437ef1c08aecf69 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 13 Sep 2024 22:19:54 -0400 Subject: [PATCH 46/89] remove dynamic module loader monkeypatch as this was fixed upstream (#1914) --- examples/deepseek-v2/qlora-fsdp-2_5.yaml | 83 +++++++++++++++++++ requirements.txt | 4 +- .../transformers_dynamic_module_utils.py | 51 ------------ src/axolotl/utils/models.py | 5 -- 4 files changed, 85 insertions(+), 58 deletions(-) create mode 100644 examples/deepseek-v2/qlora-fsdp-2_5.yaml delete mode 100644 src/axolotl/monkeypatch/transformers_dynamic_module_utils.py diff --git a/examples/deepseek-v2/qlora-fsdp-2_5.yaml b/examples/deepseek-v2/qlora-fsdp-2_5.yaml new file mode 100644 index 0000000000..6e82062d66 --- /dev/null +++ b/examples/deepseek-v2/qlora-fsdp-2_5.yaml @@ -0,0 +1,83 @@ +base_model: axolotl-quants/DeepSeek-V2.5-bnb-nf4-bf16 +trust_remote_code: true + +load_in_8bit: false +load_in_4bit: true +strict: false + + +plugins: + - axolotl.integrations.liger.LigerPlugin +liger_rms_norm: true +liger_swiglu: true +liger_fused_linear_cross_entropy: true + +chat_template: deepseek_v2 +datasets: + - path: mlabonne/FineTome-100k + type: chat_template + split: train + +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +sequence_len: 4096 +sample_packing: true +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +adapter: qlora +lora_r: 256 +lora_alpha: 256 +lora_target_linear: true +peft_use_rslora: true + +gradient_accumulation_steps: 1 +micro_batch_size: 8 +num_epochs: 1 +optimizer: adamw_torch +lr_scheduler: cosine +learning_rate: 2e-5 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +early_stopping_patience: +resume_from_checkpoint: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 100 +evals_per_epoch: 2 +eval_table_size: +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +special_tokens: +fsdp: + - full_shard + - auto_wrap +fsdp_config: + fsdp_limit_all_gathers: true + fsdp_sync_module_states: true + fsdp_offload_params: true + fsdp_use_orig_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD diff --git a/requirements.txt b/requirements.txt index 83116af60f..32a9e0e01c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 peft==0.12.0 -transformers==4.44.2 +transformers @ git+https://github.com/huggingface/transformers.git@0963229e287501bed52ae1dabc17922524de6992 tokenizers>=0.19.1 bitsandbytes==0.43.3 accelerate==0.34.2 -datasets==2.20.0 +datasets==2.21.0 deepspeed==0.14.4 pydantic==2.6.3 addict diff --git a/src/axolotl/monkeypatch/transformers_dynamic_module_utils.py b/src/axolotl/monkeypatch/transformers_dynamic_module_utils.py deleted file mode 100644 index dfc3e29c5a..0000000000 --- a/src/axolotl/monkeypatch/transformers_dynamic_module_utils.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Patch transformers.dynamic_module_utils.get_class_in_module to avoid reloading models from disk""" - -import importlib -import os -import sys -import typing -from pathlib import Path - -from transformers.file_utils import HF_MODULES_CACHE - - -def _patched_get_class_in_module( - class_name: str, module_path: typing.Union[str, os.PathLike] -) -> typing.Type: - """ - Import a module on the cache directory for modules and extract a class from it. - - Args: - class_name (`str`): The name of the class to import. - module_path (`str` or `os.PathLike`): The path to the module to import. - - Returns: - `typing.Type`: The class looked for. - """ - name = os.path.normpath(module_path) - if name.endswith(".py"): - name = name[:-3] - name = name.replace(os.path.sep, ".") - module_spec = importlib.util.spec_from_file_location( - name, location=Path(HF_MODULES_CACHE) / module_path - ) - module = sys.modules.get(name) - if module is None: - module = importlib.util.module_from_spec(module_spec) - # insert it into sys.modules before any loading begins - sys.modules[name] = module - # load in initial case only - module_spec.loader.exec_module(module) - return getattr(module, class_name) - - -def patch_transformers_dynamic_module_utils(): - """ - Recently, transformers started reloading modeling code from disk for models marked trust_remote_code=True. - This causes monkey-patches for multipack and liger to be removed. - We replace the original function with a version that does not reload the module from disk. - See https://github.com/huggingface/transformers/pull/30370#pullrequestreview-2264361581 - """ - import transformers - - transformers.dynamic_module_utils.get_class_in_module = _patched_get_class_in_module diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index e0526fb048..e183301991 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -43,9 +43,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES, patch_for_multipack, ) -from axolotl.monkeypatch.transformers_dynamic_module_utils import ( - patch_transformers_dynamic_module_utils, -) from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.chat_templates import chat_templates @@ -57,8 +54,6 @@ LOG = logging.getLogger("axolotl") -patch_transformers_dynamic_module_utils() - # copied from accelerator.FullyShardedDataParallelPlugin def get_module_class_from_name(module, name): From 7b9f669a3ab18aecb00b17e7f2885aeb458440c8 Mon Sep 17 00:00:00 2001 From: Keith Stevens Date: Sat, 14 Sep 2024 05:22:54 -0700 Subject: [PATCH 47/89] Trigger the original tokenization behavior when no advanced turn settings are provided (#1915) --- examples/phi/lora-3.5.yaml | 76 ++ .../prompt_strategies/chat_template.py | 54 +- src/axolotl/utils/chat_templates.py | 1 + .../config/models/input/v0_4_1/__init__.py | 1 + tests/prompt_strategies/conftest.py | 71 ++ .../prompt_strategies/test_chat_templates.py | 714 ++---------------- .../test_chat_templates_advanced.py | 615 +++++++++++++++ 7 files changed, 866 insertions(+), 666 deletions(-) create mode 100644 examples/phi/lora-3.5.yaml create mode 100644 tests/prompt_strategies/conftest.py create mode 100644 tests/prompt_strategies/test_chat_templates_advanced.py diff --git a/examples/phi/lora-3.5.yaml b/examples/phi/lora-3.5.yaml new file mode 100644 index 0000000000..59d667b8db --- /dev/null +++ b/examples/phi/lora-3.5.yaml @@ -0,0 +1,76 @@ +base_model: microsoft/Phi-3.5-mini-instruct +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer + +load_in_8bit: true +load_in_4bit: false +strict: false + +chat_template: phi_3 +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + chat_template: phi_3 + field_messages: messages + message_field_role: role + message_field_content: content + roles: + user: + - user + assistant: + - assistant + +dataset_prepared_path: +val_set_size: 0.05 +output_dir: ./outputs/lora-out + +sequence_len: 4096 +sample_packing: false +pad_to_sequence_len: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 4 +num_epochs: 2 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bfloat16: true +bf16: true +fp16: +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +s2_attention: + +warmup_steps: 10 +evals_per_epoch: 4 +eval_table_size: +eval_max_new_tokens: 128 +saves_per_epoch: 4 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 19e36531a5..717367eefa 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -24,8 +24,8 @@ def __init__( max_length=2048, message_field_role: str = "from", message_field_content: str = "value", - message_field_training: str = "train", - message_field_training_detail: str = "train_detail", + message_field_training: Optional[str] = None, + message_field_training_detail: Optional[str] = None, roles: Optional[Dict[str, List[str]]] = None, drop_system_message: bool = False, ): @@ -186,7 +186,7 @@ def __init__( train_on_inputs, sequence_len, roles_to_train=None, - train_on_eos="last", + train_on_eos=None, ): super().__init__(prompter, tokenizer, train_on_inputs, sequence_len) self.roles_to_train = roles_to_train if roles_to_train is not None else [] @@ -201,6 +201,37 @@ def messages(self, messages): self._messages = messages def tokenize_prompt(self, prompt): + # Old simple legacy behavior that works reliably. + if ( + not self.roles_to_train + and not self.train_on_eos + and not self.prompter.message_field_training + and not self.prompter.message_field_training_detail + ): + turns = self.get_conversation_thread(prompt) + prompt_ids = self.prompter.build_prompt( + turns[:-1], add_generation_prompt=True + ) + input_ids = self.prompter.build_prompt(turns) + + if not self.train_on_inputs: + user_prompt_len = len(prompt_ids) + labels = [-100] * user_prompt_len + input_ids[user_prompt_len:] + else: + labels = input_ids + + tokenized_prompt = { + "input_ids": input_ids, + "labels": labels, + "attention_mask": [1] * len(input_ids), + } + + return tokenized_prompt + LOG.info(self.roles_to_train) + LOG.info(self.train_on_eos) + LOG.info(self.prompter.message_field_training) + LOG.info(self.prompter.message_field_training_detail) + turns = prompt[self.messages] input_ids = self.prompter.build_prompt(turns) labels = [IGNORE_TOKEN_ID] * len(input_ids) @@ -219,9 +250,11 @@ def tokenize_prompt(self, prompt): should_train = ( train_turn if train_turn is not None - else bool(train_detail is not None) - if train_detail is not None - else self.train_on_inputs or role in self.roles_to_train + else ( + bool(train_detail is not None) + if train_detail is not None + else self.train_on_inputs or role in self.roles_to_train + ) ) LOG.debug(f"Should train: {should_train}") @@ -344,9 +377,10 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")), "message_field_role": ds_cfg.get("message_field_role", "from"), "message_field_content": ds_cfg.get("message_field_content", "value"), - "message_field_training": ds_cfg.get("message_field_training", "training"), + "message_field_training": ds_cfg.get("message_field_training", None), "message_field_training_detail": ds_cfg.get( - "message_field_training_detail", "train_detail" + "message_field_training_detail", + None, ), "roles": ds_cfg.get("roles"), "drop_system_message": ds_cfg.get("drop_system_message", False), @@ -357,8 +391,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): strategy_params = { "train_on_inputs": cfg.train_on_inputs, "sequence_len": cfg.sequence_len, - "roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]), - "train_on_eos": ds_cfg.get("train_on_eos", "turn"), + "roles_to_train": ds_cfg.get("roles_to_train", []), + "train_on_eos": ds_cfg.get("train_on_eos", None), } strategy = ChatTemplateStrategy( diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 51f88b1bdf..7a96f5c1e1 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -26,6 +26,7 @@ def chat_templates(user_choice: str): "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", + "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', } diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 9044047cce..458bacdb12 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -189,6 +189,7 @@ class ChatTemplate(str, Enum): cohere = "cohere" # pylint: disable=invalid-name llama3 = "llama3" # pylint: disable=invalid-name phi_3 = "phi_3" # pylint: disable=invalid-name + phi_35 = "phi_35" # pylint: disable=invalid-name deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name jamba = "jamba" # pylint: disable=invalid-name diff --git a/tests/prompt_strategies/conftest.py b/tests/prompt_strategies/conftest.py new file mode 100644 index 0000000000..43423f7255 --- /dev/null +++ b/tests/prompt_strategies/conftest.py @@ -0,0 +1,71 @@ +""" +shared fixtures for prompt strategies tests +""" + +import pytest +from datasets import Dataset +from transformers import AutoTokenizer + + +@pytest.fixture(name="assistant_dataset") +def fixture_assistant_dataset(): + return Dataset.from_list( + [ + { + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hello"}, + {"role": "user", "content": "goodbye"}, + {"role": "assistant", "content": "goodbye"}, + ] + } + ] + ) + + +@pytest.fixture(name="sharegpt_dataset") +def fixture_sharegpt_dataset(): + # pylint: disable=duplicate-code + return Dataset.from_list( + [ + { + "conversations": [ + {"from": "human", "value": "hello"}, + {"from": "gpt", "value": "hello"}, + {"from": "human", "value": "goodbye"}, + {"from": "gpt", "value": "goodbye"}, + ] + } + ] + ) + + +@pytest.fixture(name="basic_dataset") +def fixture_basic_dataset(): + # pylint: disable=duplicate-code + return Dataset.from_list( + [ + { + "conversations": [ + {"from": "system", "value": "You are an AI assistant."}, + {"from": "human", "value": "Hello"}, + {"from": "assistant", "value": "Hi there!"}, + {"from": "human", "value": "How are you?"}, + {"from": "assistant", "value": "I'm doing well, thank you!"}, + ] + } + ] + ) + + +@pytest.fixture(name="llama3_tokenizer") +def fixture_llama3_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct") + + return tokenizer + + +@pytest.fixture(name="phi35_tokenizer") +def fixture_phi35_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct") + return tokenizer diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index e2fc0f6a52..28210b7ae8 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -5,10 +5,6 @@ import logging import unittest -import pytest -from datasets import Dataset -from transformers import AutoTokenizer - from axolotl.prompt_strategies.chat_template import ( ChatTemplatePrompter, ChatTemplateStrategy, @@ -22,657 +18,6 @@ LOG = logging.getLogger("axolotl") -@pytest.fixture(name="assistant_dataset") -def fixture_assistant_dataset(): - return Dataset.from_list( - [ - { - "messages": [ - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": "hello"}, - {"role": "user", "content": "goodbye"}, - {"role": "assistant", "content": "goodbye"}, - ] - } - ] - ) - - -@pytest.fixture(name="sharegpt_dataset") -def fixture_sharegpt_dataset(): - # pylint: disable=duplicate-code - return Dataset.from_list( - [ - { - "conversations": [ - {"from": "human", "value": "hello"}, - {"from": "gpt", "value": "hello"}, - {"from": "human", "value": "goodbye"}, - {"from": "gpt", "value": "goodbye"}, - ] - } - ] - ) - - -@pytest.fixture(name="basic_dataset") -def fixture_basic_dataset(): - # pylint: disable=duplicate-code - return Dataset.from_list( - [ - { - "conversations": [ - {"from": "system", "value": "You are an AI assistant."}, - {"from": "human", "value": "Hello"}, - {"from": "assistant", "value": "Hi there!"}, - {"from": "human", "value": "How are you?"}, - {"from": "assistant", "value": "I'm doing well, thank you!"}, - ] - } - ] - ) - - -@pytest.fixture(name="llama3_tokenizer") -def fixture_llama3_tokenizer(): - tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct") - - return tokenizer - - -class TestChatTemplateConfigurations: - """ - Test class for various configurations of ChatTemplateStrategy. - """ - - @staticmethod - def find_sublist(full_list, sub_list): - token_count = len(sub_list) - for index in range(len(full_list) - token_count + 1): - if full_list[index : index + token_count] == sub_list: - return index - return -1 - - def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_inputs=True") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=True, - sequence_len=512, - roles_to_train=["assistant"], - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Verify that assistant responses are labeled - assistant_responses = ["Hi there!", "I'm doing well, thank you!"] - for response in assistant_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - LOG.debug( - f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" - ) - assert start_idx != -1, f"Could not find '{response}' in input_ids" - assert all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(response_ids)] - ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" - - # Check the behavior of human inputs - human_inputs = ["Hello", "How are you?"] - for input_text in human_inputs: - input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, input_ids) - labeled = all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(input_ids)] - ) - LOG.debug( - f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}" - ) - - LOG.debug("Full labels: %s", labels) - LOG.debug("Full input_ids: %s", input_ids) - - def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_inputs=False") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Verify that only assistant responses are labeled - assistant_responses = ["Hi there!", "I'm doing well, thank you!"] - for response in assistant_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - LOG.debug( - f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" - ) - assert start_idx != -1, f"Could not find '{response}' in input_ids" - assert all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(response_ids)] - ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" - - # Verify that human inputs are not labeled - human_inputs = ["Hello", "How are you?"] - for input_text in human_inputs: - input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, input_ids) - LOG.debug( - f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}" - ) - assert start_idx != -1, f"Could not find '{input_text}' in input_ids" - assert all( - label == IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(input_ids)] - ), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}" - - def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing roles_to_train with assistant only") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Verify that only assistant responses are labeled - assistant_responses = ["Hi there!", "I'm doing well, thank you!"] - for response in assistant_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - LOG.debug( - f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" - ) - assert all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(response_ids)] - ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" - - def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing roles_to_train with all roles") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=True, - sequence_len=512, - roles_to_train=["human", "assistant"], - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Verify that all responses are labeled (except for special tokens) - all_responses = [ - "Hello", - "Hi there!", - "How are you?", - "I'm doing well, thank you!", - ] - for response in all_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - LOG.debug( - f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}" - ) - assert all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(response_ids)] - ), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" - - def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with empty roles_to_train") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=[], - train_on_eos="none", # Add this line - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - - # Verify that no labels are set when roles_to_train is empty - LOG.debug("Full labels: %s", labels) - assert all( - label == IGNORE_TOKEN_ID for label in labels - ), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty" - - def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_eos='all'") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - train_on_eos="all", - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - eos_token_id = llama3_tokenizer.eos_token_id - eos_indices = [ - i for i, token_id in enumerate(input_ids) if token_id == eos_token_id - ] - - assert len(eos_indices) > 0, "Expected at least one EOS token in the input" - for eos_idx in eos_indices: - assert ( - labels[eos_idx] != IGNORE_TOKEN_ID - ), f"Expected EOS token at index {eos_idx} to be labeled" - - def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_eos='turn'") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - train_on_eos="turn", - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - eos_token_id = llama3_tokenizer.eos_token_id - assistant_responses = ["Hi there!", "I'm doing well, thank you!"] - - for response in assistant_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - assert start_idx != -1, f"Could not find '{response}' in input_ids" - - eos_idx = start_idx + len(response_ids) - while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id: - eos_idx += 1 - - assert eos_idx < len( - input_ids - ), f"Could not find EOS token after '{response}'" - assert ( - labels[eos_idx] != IGNORE_TOKEN_ID - ), f"Expected EOS token after assistant response '{response}' to be labeled" - - # Check that EOS tokens after human inputs are not labeled - human_inputs = ["Hello", "How are you?"] - for input_text in human_inputs: - input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, input_ids) - assert start_idx != -1, f"Could not find '{input_text}' in input_ids" - - eos_idx = start_idx + len(input_ids) - while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id: - eos_idx += 1 - - assert ( - labels[eos_idx] == IGNORE_TOKEN_ID - ), f"Expected EOS token after human input '{input_text}' to not be labeled" - - def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_eos='last'") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - train_on_eos="last", - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - eos_token_id = llama3_tokenizer.eos_token_id - eos_indices = [ - i for i, token_id in enumerate(input_ids) if token_id == eos_token_id - ] - - assert len(eos_indices) > 0, "Expected at least one EOS token in the input" - last_eos_idx = eos_indices[-1] - - # Check that only the last EOS token is labeled - for idx in eos_indices[:-1]: - assert ( - labels[idx] == IGNORE_TOKEN_ID - ), f"Expected EOS token at index {idx} to not be labeled" - assert ( - labels[last_eos_idx] != IGNORE_TOKEN_ID - ), f"Expected last EOS token at index {last_eos_idx} to be labeled" - - def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with train_on_eos='none'") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - train_on_eos="none", - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - eos_token_id = llama3_tokenizer.eos_token_id - eos_indices = [ - i for i, token_id in enumerate(input_ids) if token_id == eos_token_id - ] - - assert len(eos_indices) > 0, "Expected at least one EOS token in the input" - for eos_idx in eos_indices: - assert ( - labels[eos_idx] == IGNORE_TOKEN_ID - ), f"Expected EOS token at index {eos_idx} to not be labeled" - - def test_drop_system_message(self, llama3_tokenizer, basic_dataset): - LOG.info("Testing with drop_system_message=True") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter( - llama3_tokenizer, chat_templates("llama3"), drop_system_message=True - ), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["assistant"], - ) - res = strategy.tokenize_prompt(basic_dataset[0]) - input_ids = res["input_ids"] - - # Check if system message is not present in input_ids - system_message = "You are an AI assistant." - system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False) - assert ( - self.find_sublist(input_ids, system_ids) == -1 - ), "Expected system message to be dropped" - - def test_custom_roles(self, llama3_tokenizer): - LOG.info("Testing with custom roles mapping") - custom_roles = { - "user": ["human", "user"], - "assistant": ["ai", "assistant"], - "system": ["context"], - } - strategy = ChatTemplateStrategy( - ChatTemplatePrompter( - llama3_tokenizer, chat_templates("llama3"), roles=custom_roles - ), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=["ai"], - ) - - # Create a new dataset with modified role names - modified_conversations = [ - {"from": "context", "value": "You are an AI assistant."}, - {"from": "human", "value": "Hello"}, - {"from": "ai", "value": "Hi there!"}, - {"from": "human", "value": "How are you?"}, - {"from": "ai", "value": "I'm doing well, thank you!"}, - ] - - modified_dataset = Dataset.from_dict( - {"conversations": [modified_conversations]} - ) - - res = strategy.tokenize_prompt(modified_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Check if AI responses are labeled correctly - ai_responses = ["Hi there!", "I'm doing well, thank you!"] - for response in ai_responses: - response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, response_ids) - assert start_idx != -1, f"Could not find response '{response}' in input_ids" - assert all( - label != IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(response_ids)] - ), f"Expected labels for AI response '{response}' to be set" - - # Check if human messages are not labeled - human_messages = ["Hello", "How are you?"] - for message in human_messages: - message_ids = llama3_tokenizer.encode(message, add_special_tokens=False) - start_idx = self.find_sublist(input_ids, message_ids) - assert start_idx != -1, f"Could not find message '{message}' in input_ids" - assert all( - label == IGNORE_TOKEN_ID - for label in labels[start_idx : start_idx + len(message_ids)] - ), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID" - - def test_message_field_training(self, llama3_tokenizer): - LOG.info("Testing with message_field_training") - strategy = ChatTemplateStrategy( - ChatTemplatePrompter( - llama3_tokenizer, - chat_templates("llama3"), - message_field_training="train", - message_field_training_detail="train_detail", - ), - tokenizer=llama3_tokenizer, - train_on_inputs=False, - sequence_len=512, - roles_to_train=[], - ) - - # Create a new dataset with the train and train_detail fields - modified_conversation = [ - {"from": "system", "value": "You are an AI assistant.", "train": False}, - {"from": "human", "value": "Hello", "train": False}, - {"from": "assistant", "value": "Hello", "train": True}, - {"from": "human", "value": "How are you?", "train": True}, - { - "from": "assistant", - "value": "I'm doing very well, thank you!", - "train_detail": [ - {"begin_offset": 0, "end_offset": 8, "train": False}, - {"begin_offset": 9, "end_offset": 18, "train": True}, - {"begin_offset": 19, "end_offset": 30, "train": False}, - ], - }, - { - "from": "human", - "value": "I'm doing very well, thank you!", - "train": False, - }, - {"from": "assistant", "value": "Hi there!", "train": True}, - ] - - modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]}) - - res = strategy.tokenize_prompt(modified_dataset[0]) - labels = res["labels"] - input_ids = res["input_ids"] - - # Function to find all occurrences of a sublist - def find_all_sublists(full_list, sub_list): - indices = [] - for index in range(len(full_list) - len(sub_list) + 1): - if full_list[index : index + len(sub_list)] == sub_list: - indices.append(index) - return indices - - # Keep track of which occurrences we've processed - processed_occurrences = {} - # Check if messages are labeled correctly based on train or train_detail - for i, turn in enumerate(modified_conversation): - turn_tokens = llama3_tokenizer.encode( - turn["value"], add_special_tokens=False - ) - occurrences = find_all_sublists(input_ids, turn_tokens) - turn_key = turn["value"] - if turn_key not in processed_occurrences: - processed_occurrences[turn_key] = 0 - current_occurrence = processed_occurrences[turn_key] - - if current_occurrence >= len(occurrences): - assert ( - False - ), f"Not enough occurrences found for message: {turn['value']}" - - start_idx = occurrences[current_occurrence] - processed_occurrences[turn_key] += 1 - end_idx = start_idx + len(turn_tokens) - - LOG.debug( - f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}" - ) - - if "train_detail" in turn: - # Get token offsets - tokenized_output = llama3_tokenizer( - turn["value"], return_offsets_mapping=True, add_special_tokens=False - ) - token_offsets = tokenized_output["offset_mapping"] - - # Adjust token offsets as done in the implementation - for i in range(len(token_offsets) - 1): - token_offsets[i] = ( - token_offsets[i][0], - token_offsets[i + 1][0] - 1, - ) - token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1) - - # Adjust train_details - adjusted_train_details = strategy.prompter.adjust_train_details( - turn["train_detail"], token_offsets - ) - - LOG.debug(f"Original train_details: {turn['train_detail']}") - LOG.debug(f"Adjusted train_details: {adjusted_train_details}") - - # Handle train_detail - token_offsets = strategy.prompter.get_offsets_for_train_detail( - text=turn["value"], - train_details=adjusted_train_details, - mask_untrainable=False, - ) - token_offsets_masked = strategy.prompter.get_offsets_for_train_detail( - text=turn["value"], - train_details=adjusted_train_details, - mask_untrainable=True, - ) - LOG.debug(f"Token offsets: {token_offsets_masked}") - - expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens) - for i, offset in enumerate(token_offsets_masked): - if offset != IGNORE_TOKEN_ID: - expected_labels[i] = turn_tokens[i] - actual_labels = labels[ - start_idx : start_idx + len(token_offsets_masked) - ] - assert ( - actual_labels == expected_labels - ), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}" - - for detail in adjusted_train_details: - # Find the token indices that correspond to the character offsets - detail_start = start_idx + next( - i - for i, offset in enumerate(token_offsets) - if offset >= detail["begin_offset"] - ) - detail_end = start_idx + next( - ( - i - for i, offset in enumerate(token_offsets) - if offset > detail["end_offset"] - ), - len(token_offsets), - ) - - detail_text = turn["value"][ - detail["begin_offset"] : detail["end_offset"] + 1 - ] - detail_labels = labels[detail_start:detail_end] - detail_input_ids = input_ids[detail_start:detail_end] - - LOG.debug( - f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}" - ) - LOG.debug(f"Detail input_ids: {detail_input_ids}") - LOG.debug(f"Detail labels: {detail_labels}") - LOG.debug( - f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}" - ) - LOG.debug( - f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}" - ) - - if detail["train"]: - assert all( - label != IGNORE_TOKEN_ID for label in detail_labels - ), ( - f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. " - f"Labels({detail_start}:{detail_end}): {detail_labels}, " - f"InputIDs: {detail_input_ids}, " - f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'" - ) - else: - assert all( - label == IGNORE_TOKEN_ID for label in detail_labels - ), ( - f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. " - f"Labels({detail_start}:{detail_end}): {detail_labels}, " - f"InputIDs: {detail_input_ids}, " - f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'" - ) - else: - should_train = turn.get("train", False) - turn_labels = labels[start_idx:end_idx] - - LOG.debug(f"Should train: {should_train}") - LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}") - LOG.debug(f"Turn labels: {turn_labels}") - LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}") - LOG.debug( - f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}" - ) - - if should_train: - assert all(label != IGNORE_TOKEN_ID for label in turn_labels), ( - f"Expected all labels for '{turn['value']}' to be set\n" - f"Labels({start_idx}:{end_idx}): {turn_labels}, " - f"InputIDs: {input_ids[start_idx:end_idx]}, " - f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'" - ) - else: - assert all(label == IGNORE_TOKEN_ID for label in turn_labels), ( - f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n" - f"Labels({start_idx}:{end_idx}): {turn_labels}, " - f"InputIDs: {input_ids[start_idx:end_idx]}, " - f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'" - ) - - LOG.debug( - f"Processed turn: {turn['from']}, content: '{turn['value']}', " - f"start_idx: {start_idx}, end_idx: {end_idx}, " - f"labels: {labels[start_idx:end_idx]}" - ) - - LOG.debug(f"Final labels: {labels}") - LOG.debug(f"Final input_ids: {input_ids}") - - class TestAssistantChatTemplateLlama3: """ Test class for assistant style datasets with llama-3 prompts using the chat_template strategy. @@ -740,7 +85,6 @@ def test_llama3(self, llama3_tokenizer, assistant_dataset): tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, - roles_to_train=["assistant"], ) strategy.messages = "messages" res = strategy.tokenize_prompt(assistant_dataset[0]) @@ -764,6 +108,64 @@ def test_llama3(self, llama3_tokenizer, assistant_dataset): input_ids == expected_input_ids ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + def test_phi35(self, phi35_tokenizer, assistant_dataset): + LOG.info("Testing phi-3.5 with assistant dataset") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + phi35_tokenizer, + chat_templates("phi_35"), + message_field_role="role", + message_field_content="content", + roles={ + "user": ["user"], + "assistant": ["assistant"], + "system": ["system"], + }, + ), + tokenizer=phi35_tokenizer, + train_on_inputs=False, + sequence_len=512, + ) + strategy.messages = "messages" + res = strategy.tokenize_prompt(assistant_dataset[0]) + input_ids = res["input_ids"] + labels = res["labels"] + # fmt: off + expected_input_ids = [ + 32010, # user + 22172, 32007, # user eot + 32001, # assistant + 22172, 32007, # assistant eot + 32010, # user + 1781, 26966, 32007, # user eot + 32001, # assistant + 1781, 26966, 32007, # assistant eot + 32000, # eos + ] + expected_labels = [ + -100, # user + -100, -100, # user eot + -100, # assistant + -100, -100, # assistant eot, + -100, # user + -100, -100, -100, # user eot + -100, # assistant + 1781, 26966, 32007, # assistant eot + 32000, # eos + ] + # fmt: on + LOG.debug(f"Expected input_ids: {expected_input_ids}") + LOG.debug(f"Actual input_ids: {input_ids}") + assert ( + input_ids == expected_input_ids + ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + + LOG.debug(f"Expected labels : {expected_labels}") + LOG.debug(f"Actual labels : {labels}") + assert ( + labels == expected_labels + ), f"Input IDs mismatch: {labels} != {expected_labels}" + def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset): LOG.info("Testing llama-3 with assistant dataset including training data") strategy = ChatTemplateStrategy( diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py new file mode 100644 index 0000000000..f18fb39423 --- /dev/null +++ b/tests/prompt_strategies/test_chat_templates_advanced.py @@ -0,0 +1,615 @@ +""" +tests for chat_template prompt strategy +""" + +import logging +import unittest + +from datasets import Dataset + +from axolotl.prompt_strategies.chat_template import ( + ChatTemplatePrompter, + ChatTemplateStrategy, +) +from axolotl.prompters import IGNORE_TOKEN_ID +from axolotl.utils.chat_templates import chat_templates + +logging.basicConfig(level=logging.DEBUG) +LOG = logging.getLogger("axolotl") + + +class TestChatTemplateConfigurations: + """ + Test class for various configurations of ChatTemplateStrategy. + """ + + @staticmethod + def find_sublist(full_list, sub_list): + token_count = len(sub_list) + for index in range(len(full_list) - token_count + 1): + if full_list[index : index + token_count] == sub_list: + return index + return -1 + + def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_inputs=True") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=True, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that assistant responses are labeled + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert start_idx != -1, f"Could not find '{response}' in input_ids" + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + # Check the behavior of human inputs + human_inputs = ["Hello", "How are you?"] + for input_text in human_inputs: + input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, input_ids) + labeled = all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(input_ids)] + ) + LOG.debug( + f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}" + ) + + LOG.debug("Full labels: %s", labels) + LOG.debug("Full input_ids: %s", input_ids) + + def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_inputs=False") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that only assistant responses are labeled + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert start_idx != -1, f"Could not find '{response}' in input_ids" + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + # Verify that human inputs are not labeled + human_inputs = ["Hello", "How are you?"] + for input_text in human_inputs: + input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, input_ids) + LOG.debug( + f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}" + ) + assert start_idx != -1, f"Could not find '{input_text}' in input_ids" + assert all( + label == IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(input_ids)] + ), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}" + + def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing roles_to_train with assistant only") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that only assistant responses are labeled + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing roles_to_train with all roles") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=True, + sequence_len=512, + roles_to_train=["human", "assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that all responses are labeled (except for special tokens) + all_responses = [ + "Hello", + "Hi there!", + "How are you?", + "I'm doing well, thank you!", + ] + for response in all_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with empty roles_to_train") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=[], + train_on_eos="none", # Add this line + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + + # Verify that no labels are set when roles_to_train is empty + LOG.debug("Full labels: %s", labels) + assert all( + label == IGNORE_TOKEN_ID for label in labels + ), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty" + + def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='all'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="all", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + eos_indices = [ + i for i, token_id in enumerate(input_ids) if token_id == eos_token_id + ] + + assert len(eos_indices) > 0, "Expected at least one EOS token in the input" + for eos_idx in eos_indices: + assert ( + labels[eos_idx] != IGNORE_TOKEN_ID + ), f"Expected EOS token at index {eos_idx} to be labeled" + + def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='turn'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="turn", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + assert start_idx != -1, f"Could not find '{response}' in input_ids" + + eos_idx = start_idx + len(response_ids) + while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id: + eos_idx += 1 + + assert eos_idx < len( + input_ids + ), f"Could not find EOS token after '{response}'" + assert ( + labels[eos_idx] != IGNORE_TOKEN_ID + ), f"Expected EOS token after assistant response '{response}' to be labeled" + + # Check that EOS tokens after human inputs are not labeled + human_inputs = ["Hello", "How are you?"] + for input_text in human_inputs: + input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, input_ids) + assert start_idx != -1, f"Could not find '{input_text}' in input_ids" + + eos_idx = start_idx + len(input_ids) + while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id: + eos_idx += 1 + + assert ( + labels[eos_idx] == IGNORE_TOKEN_ID + ), f"Expected EOS token after human input '{input_text}' to not be labeled" + + def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='last'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="last", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + eos_indices = [ + i for i, token_id in enumerate(input_ids) if token_id == eos_token_id + ] + + assert len(eos_indices) > 0, "Expected at least one EOS token in the input" + last_eos_idx = eos_indices[-1] + + # Check that only the last EOS token is labeled + for idx in eos_indices[:-1]: + assert ( + labels[idx] == IGNORE_TOKEN_ID + ), f"Expected EOS token at index {idx} to not be labeled" + assert ( + labels[last_eos_idx] != IGNORE_TOKEN_ID + ), f"Expected last EOS token at index {last_eos_idx} to be labeled" + + def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='none'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="none", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + eos_indices = [ + i for i, token_id in enumerate(input_ids) if token_id == eos_token_id + ] + + assert len(eos_indices) > 0, "Expected at least one EOS token in the input" + for eos_idx in eos_indices: + assert ( + labels[eos_idx] == IGNORE_TOKEN_ID + ), f"Expected EOS token at index {eos_idx} to not be labeled" + + def test_drop_system_message(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with drop_system_message=True") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + llama3_tokenizer, chat_templates("llama3"), drop_system_message=True + ), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + input_ids = res["input_ids"] + + # Check if system message is not present in input_ids + system_message = "You are an AI assistant." + system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False) + assert ( + self.find_sublist(input_ids, system_ids) == -1 + ), "Expected system message to be dropped" + + def test_custom_roles(self, llama3_tokenizer): + LOG.info("Testing with custom roles mapping") + custom_roles = { + "user": ["human", "user"], + "assistant": ["ai", "assistant"], + "system": ["context"], + } + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + llama3_tokenizer, chat_templates("llama3"), roles=custom_roles + ), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["ai"], + ) + + # Create a new dataset with modified role names + modified_conversations = [ + {"from": "context", "value": "You are an AI assistant."}, + {"from": "human", "value": "Hello"}, + {"from": "ai", "value": "Hi there!"}, + {"from": "human", "value": "How are you?"}, + {"from": "ai", "value": "I'm doing well, thank you!"}, + ] + + modified_dataset = Dataset.from_dict( + {"conversations": [modified_conversations]} + ) + + res = strategy.tokenize_prompt(modified_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Check if AI responses are labeled correctly + ai_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in ai_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + assert start_idx != -1, f"Could not find response '{response}' in input_ids" + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for AI response '{response}' to be set" + + # Check if human messages are not labeled + human_messages = ["Hello", "How are you?"] + for message in human_messages: + message_ids = llama3_tokenizer.encode(message, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, message_ids) + assert start_idx != -1, f"Could not find message '{message}' in input_ids" + assert all( + label == IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(message_ids)] + ), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID" + + def test_message_field_training(self, llama3_tokenizer): + LOG.info("Testing with message_field_training") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + llama3_tokenizer, + chat_templates("llama3"), + message_field_training="train", + message_field_training_detail="train_detail", + ), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=[], + ) + + # Create a new dataset with the train and train_detail fields + modified_conversation = [ + {"from": "system", "value": "You are an AI assistant.", "train": False}, + {"from": "human", "value": "Hello", "train": False}, + {"from": "assistant", "value": "Hello", "train": True}, + {"from": "human", "value": "How are you?", "train": True}, + { + "from": "assistant", + "value": "I'm doing very well, thank you!", + "train_detail": [ + {"begin_offset": 0, "end_offset": 8, "train": False}, + {"begin_offset": 9, "end_offset": 18, "train": True}, + {"begin_offset": 19, "end_offset": 30, "train": False}, + ], + }, + { + "from": "human", + "value": "I'm doing very well, thank you!", + "train": False, + }, + {"from": "assistant", "value": "Hi there!", "train": True}, + ] + + modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]}) + + res = strategy.tokenize_prompt(modified_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Function to find all occurrences of a sublist + def find_all_sublists(full_list, sub_list): + indices = [] + for index in range(len(full_list) - len(sub_list) + 1): + if full_list[index : index + len(sub_list)] == sub_list: + indices.append(index) + return indices + + # Keep track of which occurrences we've processed + processed_occurrences = {} + # Check if messages are labeled correctly based on train or train_detail + for i, turn in enumerate(modified_conversation): + turn_tokens = llama3_tokenizer.encode( + turn["value"], add_special_tokens=False + ) + occurrences = find_all_sublists(input_ids, turn_tokens) + turn_key = turn["value"] + if turn_key not in processed_occurrences: + processed_occurrences[turn_key] = 0 + current_occurrence = processed_occurrences[turn_key] + + if current_occurrence >= len(occurrences): + assert ( + False + ), f"Not enough occurrences found for message: {turn['value']}" + + start_idx = occurrences[current_occurrence] + processed_occurrences[turn_key] += 1 + end_idx = start_idx + len(turn_tokens) + + LOG.debug( + f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}" + ) + + if "train_detail" in turn: + # Get token offsets + tokenized_output = llama3_tokenizer( + turn["value"], return_offsets_mapping=True, add_special_tokens=False + ) + token_offsets = tokenized_output["offset_mapping"] + + # Adjust token offsets as done in the implementation + for i in range(len(token_offsets) - 1): + token_offsets[i] = ( + token_offsets[i][0], + token_offsets[i + 1][0] - 1, + ) + token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1) + + # Adjust train_details + adjusted_train_details = strategy.prompter.adjust_train_details( + turn["train_detail"], token_offsets + ) + + LOG.debug(f"Original train_details: {turn['train_detail']}") + LOG.debug(f"Adjusted train_details: {adjusted_train_details}") + + # Handle train_detail + token_offsets = strategy.prompter.get_offsets_for_train_detail( + text=turn["value"], + train_details=adjusted_train_details, + mask_untrainable=False, + ) + token_offsets_masked = strategy.prompter.get_offsets_for_train_detail( + text=turn["value"], + train_details=adjusted_train_details, + mask_untrainable=True, + ) + LOG.debug(f"Token offsets: {token_offsets_masked}") + + expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens) + for i, offset in enumerate(token_offsets_masked): + if offset != IGNORE_TOKEN_ID: + expected_labels[i] = turn_tokens[i] + actual_labels = labels[ + start_idx : start_idx + len(token_offsets_masked) + ] + assert ( + actual_labels == expected_labels + ), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}" + + for detail in adjusted_train_details: + # Find the token indices that correspond to the character offsets + detail_start = start_idx + next( + i + for i, offset in enumerate(token_offsets) + if offset >= detail["begin_offset"] + ) + detail_end = start_idx + next( + ( + i + for i, offset in enumerate(token_offsets) + if offset > detail["end_offset"] + ), + len(token_offsets), + ) + + detail_text = turn["value"][ + detail["begin_offset"] : detail["end_offset"] + 1 + ] + detail_labels = labels[detail_start:detail_end] + detail_input_ids = input_ids[detail_start:detail_end] + + LOG.debug( + f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}" + ) + LOG.debug(f"Detail input_ids: {detail_input_ids}") + LOG.debug(f"Detail labels: {detail_labels}") + LOG.debug( + f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}" + ) + LOG.debug( + f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}" + ) + + if detail["train"]: + assert all( + label != IGNORE_TOKEN_ID for label in detail_labels + ), ( + f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. " + f"Labels({detail_start}:{detail_end}): {detail_labels}, " + f"InputIDs: {detail_input_ids}, " + f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'" + ) + else: + assert all( + label == IGNORE_TOKEN_ID for label in detail_labels + ), ( + f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. " + f"Labels({detail_start}:{detail_end}): {detail_labels}, " + f"InputIDs: {detail_input_ids}, " + f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'" + ) + else: + should_train = turn.get("train", False) + turn_labels = labels[start_idx:end_idx] + + LOG.debug(f"Should train: {should_train}") + LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}") + LOG.debug(f"Turn labels: {turn_labels}") + LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}") + LOG.debug( + f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}" + ) + + if should_train: + assert all(label != IGNORE_TOKEN_ID for label in turn_labels), ( + f"Expected all labels for '{turn['value']}' to be set\n" + f"Labels({start_idx}:{end_idx}): {turn_labels}, " + f"InputIDs: {input_ids[start_idx:end_idx]}, " + f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'" + ) + else: + assert all(label == IGNORE_TOKEN_ID for label in turn_labels), ( + f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n" + f"Labels({start_idx}:{end_idx}): {turn_labels}, " + f"InputIDs: {input_ids[start_idx:end_idx]}, " + f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'" + ) + + LOG.debug( + f"Processed turn: {turn['from']}, content: '{turn['value']}', " + f"start_idx: {start_idx}, end_idx: {end_idx}, " + f"labels: {labels[start_idx:end_idx]}" + ) + + LOG.debug(f"Final labels: {labels}") + LOG.debug(f"Final input_ids: {input_ids}") + + +if __name__ == "__main__": + unittest.main() From d7eea2ff343e9f0653ce82fc1826b41efa9bc2f6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 24 Sep 2024 14:05:58 -0400 Subject: [PATCH 48/89] validation fixes 20240923 (#1925) * validation fixes 20240923 * fix run name for wandb and defaults for chat template fields * fix gradio inference with llama chat template --- src/axolotl/cli/__init__.py | 27 +++++++++++++++++-- src/axolotl/core/trainer_builder.py | 8 ++++++ .../prompt_strategies/chat_template.py | 4 +-- .../config/models/input/v0_4_1/__init__.py | 10 ++++++- 4 files changed, 44 insertions(+), 5 deletions(-) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index aaa62423ca..13c5b4ab58 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -30,6 +30,7 @@ from axolotl.integrations.base import PluginManager from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta +from axolotl.utils.chat_templates import chat_templates from axolotl.utils.config import ( normalize_cfg_datasets, normalize_config, @@ -234,7 +235,8 @@ def do_inference_gradio( model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) prompter = cli_args.prompter - default_tokens = {"unk_token": "", "bos_token": "", "eos_token": ""} + # default_tokens = {"unk_token": "", "bos_token": "", "eos_token": ""} + default_tokens: Dict[str, str] = {} for token, symbol in default_tokens.items(): # If the token isn't already specified in the config, add it @@ -242,10 +244,13 @@ def do_inference_gradio( tokenizer.add_special_tokens({token: symbol}) prompter_module = None + chat_template_str = None if prompter: prompter_module = getattr( importlib.import_module("axolotl.prompters"), prompter ) + elif cfg.chat_template: + chat_template_str = chat_templates(cfg.chat_template) model = model.to(cfg.device, dtype=cfg.torch_dtype) @@ -259,7 +264,24 @@ def generate(instruction): ) else: prompt = instruction.strip() - batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) + + if chat_template_str: + batch = tokenizer.apply_chat_template( + [ + { + "role": "user", + "content": prompt, + } + ], + return_tensors="pt", + add_special_tokens=True, + add_generation_prompt=True, + chat_template=chat_template_str, + tokenize=True, + return_dict=True, + ) + else: + batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) model.eval() with torch.no_grad(): @@ -282,6 +304,7 @@ def generate(instruction): streamer = TextIteratorStreamer(tokenizer) generation_kwargs = { "inputs": batch["input_ids"].to(cfg.device), + "attention_mask": batch["attention_mask"].to(cfg.device), "generation_config": generation_config, "streamer": streamer, } diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index f4cd257838..7c3e437f80 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1417,6 +1417,8 @@ def build(self, total_num_steps): report_to = [] if self.cfg.use_wandb: report_to.append("wandb") + if self.cfg.wandb_name: + training_arguments_kwargs["run_name"] = self.cfg.wandb_name if self.cfg.use_mlflow: report_to.append("mlflow") if self.cfg.use_tensorboard: @@ -1574,6 +1576,12 @@ def build(self, total_num_steps): ) training_args = self.hook_post_create_training_args(training_args) + # unset run_name so wandb sets up experiment names + if self.cfg.use_wandb and training_args.run_name == training_args.output_dir: + training_args.run_name = ( # pylint: disable=attribute-defined-outside-init + None + ) + data_collator_kwargs = { "padding": True, # True/"longest" is the default } diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 717367eefa..88e748895d 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -375,8 +375,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): prompter_params = { "tokenizer": tokenizer, "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")), - "message_field_role": ds_cfg.get("message_field_role", "from"), - "message_field_content": ds_cfg.get("message_field_content", "value"), + "message_field_role": ds_cfg.get("message_field_role", "role"), + "message_field_content": ds_cfg.get("message_field_content", "content"), "message_field_training": ds_cfg.get("message_field_training", None), "message_field_training_detail": ds_cfg.get( "message_field_training_detail", diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 458bacdb12..2217855083 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -1017,12 +1017,20 @@ def validate_neftune_noise_alpha(cls, neftune_noise_alpha): return neftune_noise_alpha @model_validator(mode="after") - def check(self): + def check_rl_beta(self): if self.dpo_beta and not self.rl_beta: self.rl_beta = self.dpo_beta del self.dpo_beta return self + @model_validator(mode="after") + def check_simpo_warmup(self): + if self.rl == "simpo" and self.warmup_ratio: + raise ValueError( + "warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead" + ) + return self + @model_validator(mode="before") @classmethod def check_frozen(cls, data): From b98d7d7098f5d64a07c5a96855c4e08dca7afd91 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 26 Sep 2024 11:33:41 -0400 Subject: [PATCH 49/89] update upstream deps versions and replace lora+ (#1928) * update upstream deps versions and replace lora+ * typo transformers version --- requirements.txt | 8 +- src/axolotl/core/trainer_builder.py | 14 +-- src/axolotl/loraplus.py | 133 ---------------------------- 3 files changed, 11 insertions(+), 144 deletions(-) delete mode 100644 src/axolotl/loraplus.py diff --git a/requirements.txt b/requirements.txt index 32a9e0e01c..3f17e5d329 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,9 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 -peft==0.12.0 -transformers @ git+https://github.com/huggingface/transformers.git@0963229e287501bed52ae1dabc17922524de6992 +peft==0.13.0 +transformers==4.45.0 tokenizers>=0.19.1 -bitsandbytes==0.43.3 +bitsandbytes==0.44.0 accelerate==0.34.2 datasets==2.21.0 deepspeed==0.14.4 @@ -34,7 +34,7 @@ tensorboard python-dotenv==1.0.1 autoawq>=0.2.5 triton>=2.3.0 -liger-kernel==0.2.1 +liger-kernel==0.3.0 mamba-ssm==1.2.0.post1 diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 7c3e437f80..23ac0952ed 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -21,6 +21,7 @@ import torch import transformers from datasets import Dataset +from peft.optimizers import create_loraplus_optimizer from torch import nn from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler @@ -45,7 +46,6 @@ ) from trl.trainer.utils import pad_to_length -from axolotl.loraplus import create_loraplus_optimizer from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.utils import is_mlflow_available @@ -461,9 +461,9 @@ def create_optimizer(self): self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init opt_model, optimizer_cls, - optimizer_kwargs, - loraplus_lr_ratio, - loraplus_lr_embedding, + loraplus_lr_ratio=loraplus_lr_ratio, + loraplus_lr_embedding=loraplus_lr_embedding, + **optimizer_kwargs, ) elif self.args.alternate_optimizer == "optimi_adamw": from optimi import AdamW @@ -969,9 +969,9 @@ def create_optimizer(self): self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init opt_model, optimizer_cls, - optimizer_kwargs, - loraplus_lr_ratio, - loraplus_lr_embedding, + loraplus_lr_ratio=loraplus_lr_ratio, + loraplus_lr_embedding=loraplus_lr_embedding, + **optimizer_kwargs, ) if is_sagemaker_mp_enabled(): diff --git a/src/axolotl/loraplus.py b/src/axolotl/loraplus.py deleted file mode 100644 index b4abec55ad..0000000000 --- a/src/axolotl/loraplus.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Module for LoRA+""" - -# MIT License -# -# Copyright (c) 2024 nikhil-ghosh-berkeley -# https://github.com/nikhil-ghosh-berkeley/loraplus - -import logging -from functools import reduce - -from peft.tuners import lora -from torch import nn -from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS -from transformers.trainer_pt_utils import get_parameter_names - -LOG = logging.getLogger("axolotl.loraplus") - - -def get_module(name, opt_model): - """ - Retrieve a module from a model using its parameter name. - Args: - name (str): Full name of the parameter, typically including module path. - opt_model (torch.nn.Module): The model from which to retrieve the module. - - Returns: - Module corresponding to the given name. - """ - parent_idx = 2 if "lora" in name else 1 - module_names = name.split(sep=".")[:-parent_idx] - module = reduce(getattr, module_names, opt_model) - return module - - -def create_loraplus_optimizer( - opt_model, - optimizer_cls, - optimizer_kwargs, - loraplus_lr_ratio, - loraplus_lr_embedding=None, -): - """ - Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups. - - Args: - opt_model (torch.nn.Module): The model for which the optimizer is being created. - optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam). - optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization. - loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters. - loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided. - - Returns: - An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates. - """ - - assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided." - - if loraplus_lr_embedding is None: - loraplus_lr_embedding = 1e-6 - - decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) - decay_parameters = [name for name in decay_parameters if "bias" not in name] - param_groups = { - "groupA": {}, - "groupB": {}, - "groupB_no_decay": {}, - "embedding": {}, - } - - for name, param in opt_model.named_parameters(): - if not param.requires_grad: - continue - - module = get_module(name, opt_model) - if isinstance(module, lora.Embedding): - param_groups["embedding"][name] = param - elif "lora_B" in name or param.ndim == 1: - if name in decay_parameters: - param_groups["groupB"][name] = param - else: - param_groups["groupB_no_decay"][name] = param - else: - param_groups["groupA"][name] = param - - assigned_param_groups = "" - for group, group_params in param_groups.items(): - assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n" - LOG.info(assigned_param_groups) - - lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name - weight_decay = optimizer_kwargs.get("weight_decay", 0.0) - - optimizer_grouped_parameters = [ - { - "params": list(param_groups["groupA"].values()), - "weight_decay": weight_decay, - "lr": lr, - }, - { - "params": list(param_groups["embedding"].values()), - "weight_decay": weight_decay, - "lr": loraplus_lr_embedding, - }, - { - "params": list(param_groups["groupB"].values()), - "weight_decay": weight_decay, - "lr": lr * loraplus_lr_ratio, - }, - { - "params": list(param_groups["groupB_no_decay"].values()), - "weight_decay": 0.0, - "lr": lr * loraplus_lr_ratio, - }, - ] - - optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) - if optimizer_cls.__name__ == "Adam8bit": - import bitsandbytes - - manager = bitsandbytes.optim.GlobalOptimManager.get_instance() - - skipped = 0 - for module in opt_model.modules(): - if isinstance(module, nn.Embedding): - skipped += sum( - {p.data_ptr(): p.numel() for p in module.parameters()}.values() - ) - LOG.info(f"skipped {module}: {skipped/2**20}M params") - manager.register_module_override(module, "weight", {"optim_bits": 32}) - LOG.debug(f"bitsandbytes: will optimize {module} in fp32") - LOG.info(f"skipped: {skipped/2**20}M params") - - return optimizer From 61aa291119e90dbebf5612be42cd5cad3729bc6e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 27 Sep 2024 15:58:35 -0400 Subject: [PATCH 50/89] fix for empty lora+ lr embedding (#1932) --- src/axolotl/core/trainer_builder.py | 2 +- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 23ac0952ed..249398f850 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -456,7 +456,7 @@ def create_optimizer(self): if self.args.loraplus_lr_ratio is not None: loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) loraplus_lr_embedding = getattr( - self.args, "loraplus_lr_embedding", None + self.args, "loraplus_lr_embedding", 1e-6 ) self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init opt_model, diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 2217855083..4e07c9260a 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -298,6 +298,13 @@ def validate_qlora(self): raise ValueError("Require cfg.load_in_4bit to be True for qlora") return self + @field_validator("loraplus_lr_embedding") + @classmethod + def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding): + if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str): + loraplus_lr_embedding = float(loraplus_lr_embedding) + return loraplus_lr_embedding + class ReLoRAConfig(BaseModel): """ReLoRA configuration subset""" From 844331005c1ef45430ff26b9f42f757dce6ee66a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 30 Sep 2024 13:56:12 -0400 Subject: [PATCH 51/89] bump transformers to 4.45.1 (#1936) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3f17e5d329..123a4ee54a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 peft==0.13.0 -transformers==4.45.0 +transformers==4.45.1 tokenizers>=0.19.1 bitsandbytes==0.44.0 accelerate==0.34.2 From e1915f5625b2330555c3f61816dd003fb939ae13 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 2 Oct 2024 21:02:48 -0400 Subject: [PATCH 52/89] Multimodal Vision Llama - rudimentary support (#1940) --------- Co-authored-by: Sunny Co-authored-by: sunny --- docs/input_output.qmd | 2 +- docs/multimodal.qmd | 28 +++ examples/llama-3-vision/lora-11b.yaml | 63 +++++ src/axolotl/cli/__init__.py | 7 +- src/axolotl/core/trainer_builder.py | 20 +- src/axolotl/monkeypatch/attention/mllama.py | 229 ++++++++++++++++++ src/axolotl/monkeypatch/multipack.py | 1 + .../monkeypatch/stablelm_attn_hijack_flash.py | 1 + src/axolotl/prompt_strategies/__init__.py | 4 +- .../prompt_strategies/chat_template.py | 60 ++++- src/axolotl/train.py | 10 +- src/axolotl/utils/chat_templates.py | 46 ++-- src/axolotl/utils/collators/__init__.py | 10 + .../{collators.py => collators/batching.py} | 35 +-- src/axolotl/utils/collators/core.py | 4 + src/axolotl/utils/collators/mamba.py | 38 +++ src/axolotl/utils/collators/mm_chat.py | 77 ++++++ src/axolotl/utils/config/__init__.py | 25 +- .../config/models/input/v0_4_1/__init__.py | 29 ++- src/axolotl/utils/data/sft.py | 48 +++- src/axolotl/utils/models.py | 98 ++++++-- src/axolotl/utils/trainer.py | 16 +- .../prompt_strategies/test_chat_templates.py | 21 +- .../test_chat_templates_advanced.py | 46 +++- 24 files changed, 799 insertions(+), 119 deletions(-) create mode 100644 docs/multimodal.qmd create mode 100644 examples/llama-3-vision/lora-11b.yaml create mode 100644 src/axolotl/monkeypatch/attention/mllama.py create mode 100644 src/axolotl/utils/collators/__init__.py rename src/axolotl/utils/{collators.py => collators/batching.py} (90%) create mode 100644 src/axolotl/utils/collators/core.py create mode 100644 src/axolotl/utils/collators/mamba.py create mode 100644 src/axolotl/utils/collators/mm_chat.py diff --git a/docs/input_output.qmd b/docs/input_output.qmd index 7715dd250d..6559578d18 100644 --- a/docs/input_output.qmd +++ b/docs/input_output.qmd @@ -205,7 +205,7 @@ ds = load_from_disk(f'last_run_prepared/{directory[0]}/') hi there!. goodbye farewell ``` -We can check that the right tokens are ingored by comparing the labels +We can check that the right tokens are ignored by comparing the labels to each token: ```python diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd new file mode 100644 index 0000000000..2381566adb --- /dev/null +++ b/docs/multimodal.qmd @@ -0,0 +1,28 @@ +# MultiModal / Vision Language Models (BETA) + +### Supported Models + +- Mllama, i.e. llama with vision models + +### Usage + +Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA, +you'll need to use the following in YAML in combination with the rest of the required hyperparams. + +```yaml +base_model: alpindale/Llama-3.2-11B-Vision-Instruct +processor_type: AutoProcessor +skip_prepare_dataset: true + +chat_template: llama3_2_vision +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] + field_messages: messages +remove_unused_columns: false +sample_packing: false + +# only finetune the Language model, leave the vision model and vision tower frozen +lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' +``` diff --git a/examples/llama-3-vision/lora-11b.yaml b/examples/llama-3-vision/lora-11b.yaml new file mode 100644 index 0000000000..b2e4946418 --- /dev/null +++ b/examples/llama-3-vision/lora-11b.yaml @@ -0,0 +1,63 @@ +base_model: alpindale/Llama-3.2-11B-Vision-Instruct +processor_type: AutoProcessor +strict: false + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +chat_template: llama3_2_vision +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] + field_messages: messages +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: lora +lora_model_dir: + +sequence_len: 8192 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +local_rank: +logging_steps: 1 +flash_attention: true +eager_attention: + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 13c5b4ab58..a1d84b6a16 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -40,7 +40,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process from axolotl.utils.mlflow_ import setup_mlflow_env_vars -from axolotl.utils.models import load_tokenizer +from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env from axolotl.utils.wandb_ import setup_wandb_env_vars @@ -430,9 +430,12 @@ def load_datasets( cli_args: TrainerCliArgs, ) -> TrainDatasetMeta: tokenizer = load_tokenizer(cfg) + processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( - cfg, tokenizer + cfg, + tokenizer, + processor=processor, ) if cli_args.debug or cfg.debug: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 249398f850..4893e63dc2 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -61,12 +61,14 @@ log_prediction_callback_factory, ) from axolotl.utils.callbacks.lisa import lisa_callback_factory +from axolotl.utils.chat_templates import chat_templates from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, MambaDataCollator, V2BatchSamplerDataCollatorForSeq2Seq, ) +from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.models import ensure_dtype from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.schedulers import ( @@ -250,6 +252,10 @@ class AxolotlTrainingMixins: "help": "workaround to pass an alternate lr scheduler to the HF trainer" }, ) + chat_template: Optional[str] = field( + default=None, + metadata={"help": "Chat template converting chat messages to text"}, + ) @dataclass @@ -1043,10 +1049,11 @@ class TrainerBuilderBase(abc.ABC): _model_ref = None _peft_config = None - def __init__(self, cfg, model, tokenizer): + def __init__(self, cfg, model, tokenizer, processor=None): self.cfg = cfg self.model = model self.tokenizer = tokenizer + self.processor = processor # in case the model supports tagging, add the axolotl tag. # This makes sure the tag is correctly pushed even if a user calls @@ -1515,6 +1522,10 @@ def build(self, total_num_steps): ) training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) + if self.cfg.chat_template: + training_arguments_kwargs["chat_template"] = chat_templates( + self.cfg.chat_template + ) if self.cfg.rl == "orpo": training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha @@ -1661,7 +1672,12 @@ def build_collator( else: collator = BatchSamplerDataCollatorForSeq2Seq else: - collator = DataCollatorForSeq2Seq + if self.cfg.processor_type and self.processor: + collator = MultiModalChatDataCollator + kwargs["processor"] = self.processor + kwargs["chat_template"] = training_args.chat_template + else: + collator = DataCollatorForSeq2Seq return collator( self.tokenizer, diff --git a/src/axolotl/monkeypatch/attention/mllama.py b/src/axolotl/monkeypatch/attention/mllama.py new file mode 100644 index 0000000000..0b18b716d5 --- /dev/null +++ b/src/axolotl/monkeypatch/attention/mllama.py @@ -0,0 +1,229 @@ +""" +Monkeypatch for Vision Llama for FA2 support +""" +# pylint: disable=duplicate-code + +from typing import Optional, Tuple + +import torch +from flash_attn.flash_attn_interface import flash_attn_func +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.models.mllama.configuration_mllama import MllamaTextConfig +from transformers.models.mllama.modeling_mllama import ( + MllamaTextCrossAttention, + MllamaTextSelfAttention, + apply_rotary_pos_emb, + repeat_kv, +) +from transformers.utils import is_flash_attn_greater_or_equal_2_10 + + +class MllamaTextCrossFlashAttention2(MllamaTextCrossAttention): + """ + Mllama flash cross-attention module. This module inherits from `MllamaTextCrossAttention` and + implements the forward pass using Flash Attention for improved performance. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Check if flash attention version is greater or equal to 2.1 + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[ # pylint: disable=unused-argument + torch.Tensor + ] = None, + output_attentions: bool = False, + use_cache: bool = False, # pylint: disable=unused-argument + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + query_states = self.q_norm(query_states) + + if cross_attention_states is not None: + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view( + bsz, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, -1, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states = self.k_norm(key_states) + if past_key_value is not None: + key_states, value_states = past_key_value.update( + key_states, + value_states, + self.layer_idx, + {"cache_position": cache_position}, + ) + elif cache_position[0] != 0: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + else: + raise ValueError( + "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" + ) + + # Transpose to get the expected layout for flash attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # Apply Flash Attention + dropout_rate = self.dropout if self.training else 0.0 + output = flash_attn_func( + query_states, + key_states, + value_states, + dropout_p=dropout_rate, + softmax_scale=None, + causal=False, + return_attn_probs=output_attentions, + ) + + attn_output = output.contiguous().view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MllamaTextSelfFlashAttention2(MllamaTextSelfAttention): + """ + Mllama flash self-attention module. This module inherits from `MllamaTextSelfAttention` and + implements the forward pass using Flash Attention for improved performance. + """ + + def __init__(self, config: MllamaTextConfig, layer_idx: int, *args, **kwargs): + super().__init__(config, layer_idx, *args, **kwargs) + + # Check if flash attention version is greater or equal to 2.1 + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, # pylint: disable=unused-argument + past_key_value=None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, # pylint: disable=unused-argument + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x num_heads x head_dim + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Transpose to get the expected layout for flash attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.dropout if self.training else 0.0 + + # Handle potential silent casting to float32 + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = ( + self.config._pre_quantization_dtype # pylint: disable=protected-access + ) + else: + target_dtype = self.q_proj.weight.dtype + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=True, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def patch_mllama(): + from transformers.models.mllama.modeling_mllama import ( + MLLAMA_TEXT_ATTENTION_CLASSES, + MLLAMA_TEXT_CROSS_ATTENTION_CLASSES, + MLLAMA_VISION_ATTENTION_CLASSES, + MllamaPreTrainedModel, + ) + + MllamaPreTrainedModel._supports_flash_attn_2 = ( # pylint: disable=protected-access + True + ) + MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2 + MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[ + "flash_attention_2" + ] = MllamaTextCrossFlashAttention2 + # fallback to SDPA + MLLAMA_VISION_ATTENTION_CLASSES[ + "flash_attention_2" + ] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"] diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 44fc4cb473..85101cd3c4 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -10,6 +10,7 @@ from axolotl.monkeypatch.utils import get_unpad_data SUPPORTED_MULTIPACK_MODEL_TYPES = [ + "mllama_text_model", "llama", "mistral", "mixtral", diff --git a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py index 0269f90157..67e9337e36 100644 --- a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py @@ -16,6 +16,7 @@ # This code is based off the following work: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py +# pylint: disable=duplicate-code """ PyTorch StableLM Epoch model. """ import importlib import math diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py index f5699a0871..66cd5deeb9 100644 --- a/src/axolotl/prompt_strategies/__init__.py +++ b/src/axolotl/prompt_strategies/__init__.py @@ -9,7 +9,7 @@ LOG = logging.getLogger("axolotl.prompt_strategies") -def load(strategy, tokenizer, cfg, ds_cfg): +def load(strategy, tokenizer, cfg, ds_cfg, processor=None): try: load_fn = "load" if strategy.split(".")[-1].startswith("load_"): @@ -24,6 +24,8 @@ def load(strategy, tokenizer, cfg, ds_cfg): sig = inspect.signature(func) if "ds_cfg" in sig.parameters: load_kwargs["ds_cfg"] = ds_cfg + if "processor" in sig.parameters: + load_kwargs["processor"] = processor return func(tokenizer, cfg, **load_kwargs) except ModuleNotFoundError: return None diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 88e748895d..48d52dae11 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -5,6 +5,8 @@ import logging from typing import Any, Dict, List, Optional +from transformers import ProcessorMixin + from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompters import IGNORE_TOKEN_ID, Prompter from axolotl.utils.chat_templates import chat_templates @@ -20,6 +22,7 @@ class ChatTemplatePrompter(Prompter): def __init__( self, tokenizer, + processor=None, chat_template=None, max_length=2048, message_field_role: str = "from", @@ -44,11 +47,12 @@ def __init__( self.message_field_training = message_field_training self.message_field_training_detail = message_field_training_detail self.tokenizer = tokenizer + self.processor: ProcessorMixin = processor self.chat_template = chat_template self.max_length = max_length self.drop_system_message = drop_system_message - def build_prompt(self, conversation, add_generation_prompt=False): + def build_prompt(self, conversation, add_generation_prompt=False, images=None): turns = [ { "role": self.roles[t[self.message_field_role]], @@ -61,6 +65,28 @@ def build_prompt(self, conversation, add_generation_prompt=False): if self.drop_system_message and turns[0]["role"] == "system": turns = turns[1:] + if self.processor: + text = self.processor.apply_chat_template( + turns, + chat_template=self.chat_template, + tokenize=False, + add_generation_prompt=add_generation_prompt, + ) + batch = self.processor( + text=text, + images=images, + return_tensors="pt", + truncation=True, + max_length=self.max_length, + ) + # workaround since processor works in batches instead of single examples + for k, val in batch.items(): + if k in ["pixel_values"]: + batch[k] = val.tolist() + else: + batch[k] = val.squeeze().tolist() + return batch + return self.tokenizer.apply_chat_template( turns, truncation=True, @@ -191,6 +217,7 @@ def __init__( super().__init__(prompter, tokenizer, train_on_inputs, sequence_len) self.roles_to_train = roles_to_train if roles_to_train is not None else [] self.train_on_eos = train_on_eos + self.images = "images" @property def messages(self): @@ -209,10 +236,21 @@ def tokenize_prompt(self, prompt): and not self.prompter.message_field_training_detail ): turns = self.get_conversation_thread(prompt) + images = self.get_images(prompt) prompt_ids = self.prompter.build_prompt( - turns[:-1], add_generation_prompt=True + turns[:-1], + add_generation_prompt=True, + images=images, ) - input_ids = self.prompter.build_prompt(turns) + tokenized_res = self.prompter.build_prompt(turns, images=images) + tokenized_prompt = {} + if isinstance(tokenized_res, list): + input_ids = prompt_ids + tokenized_res[len(prompt_ids) :] + tokenized_prompt["input_ids"] = input_ids + tokenized_prompt["attention_mask"] = [1] * len(input_ids) + else: + input_ids = tokenized_res["input_ids"] + tokenized_prompt = tokenized_res if not self.train_on_inputs: user_prompt_len = len(prompt_ids) @@ -220,17 +258,9 @@ def tokenize_prompt(self, prompt): else: labels = input_ids - tokenized_prompt = { - "input_ids": input_ids, - "labels": labels, - "attention_mask": [1] * len(input_ids), - } + tokenized_prompt["labels"] = labels return tokenized_prompt - LOG.info(self.roles_to_train) - LOG.info(self.train_on_eos) - LOG.info(self.prompter.message_field_training) - LOG.info(self.prompter.message_field_training_detail) turns = prompt[self.messages] input_ids = self.prompter.build_prompt(turns) @@ -368,8 +398,11 @@ def find_turn(self, conversation_ids, turn, turn_content): def get_conversation_thread(self, prompt): return prompt[self.messages] + def get_images(self, prompt): + return prompt.get(self.images, None) + -def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): +def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None): ds_cfg = ds_cfg or {} prompter_params = { @@ -386,6 +419,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): "drop_system_message": ds_cfg.get("drop_system_message", False), # we need to add one for detecting sequences with exceeding the `sequence_len` limit. "max_length": cfg.sequence_len + 1, + "processor": processor, } strategy_params = { diff --git a/src/axolotl/train.py b/src/axolotl/train.py index b21b0b269c..855dbc2d3b 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -24,7 +24,7 @@ from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault from axolotl.utils.freeze import freeze_layers_except -from axolotl.utils.models import load_model, load_tokenizer +from axolotl.utils.models import load_model, load_processor, load_tokenizer from axolotl.utils.trainer import setup_trainer try: @@ -69,6 +69,9 @@ def train( main_process_only=True, ) tokenizer = load_tokenizer(cfg) + processor = None + if cfg.is_multimodal: + processor = load_processor(cfg, tokenizer) train_dataset = dataset_meta.train_dataset eval_dataset = dataset_meta.eval_dataset @@ -96,7 +99,9 @@ def train( LOG.debug(msg) # we wait unitl the last possible moment to setup Accelerator Accelerator() - model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) + model, peft_config = load_model( + cfg, tokenizer, processor=processor, inference=cli_args.inference + ) model.generation_config.do_sample = True model_ref = None @@ -122,6 +127,7 @@ def train( eval_dataset, (model, model_ref, peft_config), tokenizer, + processor, total_num_steps, ) diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 7a96f5c1e1..7468ae8b15 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -3,6 +3,20 @@ These templates are used for formatting messages in a conversation. """ +CHAT_TEMPLATES = { + "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", + "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. + "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", + "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", + "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", + "llama3_2_vision": '{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now("%d %b %Y") %}\n {%- else %}\n {%- set date_string = "26 Jul 2024" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0][\'role\'] == \'system\' %}\n {%- set system_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == "" %}\n {{- raise_exception("Prompting with images is incompatible with system messages.") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n {%- if tools is not none %}\n {{- "Environment: ipython\\n" }}\n {%- endif %}\n {{- "Cutting Knowledge Date: December 2023\\n" }}\n {{- "Today Date: " + date_string + "\\n\\n" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- "<|eot_id|>" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\' }}\n {%- if message[\'content\'] is string %}\n {{- message[\'content\'] }}\n {%- else %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {{- \'<|image|>\' }}\n {%- elif content[\'type\'] == \'text\' %}\n {{- content[\'text\'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \'<|eot_id|>\' }}\n {%- elif \'tool_calls\' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- \'{"name": "\' + tool_call.name + \'", \' }}\n {{- \'"parameters": \' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {{- "<|eot_id|>" }}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n', + "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", + "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", + "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", + "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', +} + def chat_templates(user_choice: str): """ @@ -18,20 +32,22 @@ def chat_templates(user_choice: str): ValueError: If the user_choice is not found in the templates. """ - templates = { - "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", - "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. - "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", - "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", - "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", - "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", - "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", - "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", - "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", - "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', - } - - if user_choice in templates: - return templates[user_choice] + if user_choice in CHAT_TEMPLATES: + return CHAT_TEMPLATES[user_choice] raise ValueError(f"Template '{user_choice}' not found.") + + +def register_chat_template(template_name: str, chat_template: str): + """ + Registers chat templates. + + Args: + template_name (str): The name of the template. + chat_template (str): The template string. + """ + + if template_name in CHAT_TEMPLATES: + raise ValueError(f"Template '{template_name}' already exists.") + + CHAT_TEMPLATES[template_name] = chat_template diff --git a/src/axolotl/utils/collators/__init__.py b/src/axolotl/utils/collators/__init__.py new file mode 100644 index 0000000000..93502b67d7 --- /dev/null +++ b/src/axolotl/utils/collators/__init__.py @@ -0,0 +1,10 @@ +""" +shared axolotl collators for multipack, mamba, multimodal +""" +from .batching import ( # noqa: F401 + BatchSamplerDataCollatorForSeq2Seq, + DataCollatorForSeq2Seq, + PretrainingBatchSamplerDataCollatorForSeq2Seq, + V2BatchSamplerDataCollatorForSeq2Seq, +) +from .mamba import MambaDataCollator # noqa: F401 diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators/batching.py similarity index 90% rename from src/axolotl/utils/collators.py rename to src/axolotl/utils/collators/batching.py index 26c7fa9f3c..7cf771421c 100644 --- a/src/axolotl/utils/collators.py +++ b/src/axolotl/utils/collators/batching.py @@ -1,17 +1,14 @@ """ DataCollator for axolotl to pad labels and position_ids for packed sequences """ + from dataclasses import dataclass -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Optional, Union import numpy as np -import torch -import transformers from transformers import PreTrainedTokenizerBase from transformers.utils import PaddingStrategy -IGNORE_INDEX = -100 - @dataclass class DataCollatorForSeq2Seq: @@ -183,34 +180,6 @@ def __call__(self, features, return_tensors=None): return super().__call__(out_features, return_tensors=return_tensors) -@dataclass -class MambaDataCollator: - """ - Collator for State Space Models (Mamba) - """ - - tokenizer: transformers.PreTrainedTokenizer - - def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: - input_ids, labels = tuple( - [torch.LongTensor(instance[key]) for instance in instances] - for key in ("input_ids", "labels") - ) - input_ids = torch.nn.utils.rnn.pad_sequence( - input_ids, - batch_first=True, - padding_value=self.tokenizer.pad_token_id, - ) - labels = torch.nn.utils.rnn.pad_sequence( - labels, batch_first=True, padding_value=IGNORE_INDEX - ) - - return { - "input_ids": input_ids, - "labels": labels, - } - - @dataclass class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): """ diff --git a/src/axolotl/utils/collators/core.py b/src/axolotl/utils/collators/core.py new file mode 100644 index 0000000000..0eae0c3bda --- /dev/null +++ b/src/axolotl/utils/collators/core.py @@ -0,0 +1,4 @@ +""" +basic shared collator constants +""" +IGNORE_INDEX = -100 diff --git a/src/axolotl/utils/collators/mamba.py b/src/axolotl/utils/collators/mamba.py new file mode 100644 index 0000000000..0c4a22fcc0 --- /dev/null +++ b/src/axolotl/utils/collators/mamba.py @@ -0,0 +1,38 @@ +""" +collators for Mamba +""" +from dataclasses import dataclass +from typing import Dict, Sequence + +import torch +import transformers + +from axolotl.utils.collators.core import IGNORE_INDEX + + +@dataclass +class MambaDataCollator: + """ + Collator for State Space Models (Mamba) + """ + + tokenizer: transformers.PreTrainedTokenizer + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple( + [torch.LongTensor(instance[key]) for instance in instances] + for key in ("input_ids", "labels") + ) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id, + ) + labels = torch.nn.utils.rnn.pad_sequence( + labels, batch_first=True, padding_value=IGNORE_INDEX + ) + + return { + "input_ids": input_ids, + "labels": labels, + } diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py new file mode 100644 index 0000000000..f49e97f37f --- /dev/null +++ b/src/axolotl/utils/collators/mm_chat.py @@ -0,0 +1,77 @@ +""" +Collators for multi-modal chat messages and packing +""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +from transformers import PreTrainedTokenizerBase, ProcessorMixin +from transformers.data.data_collator import DataCollatorMixin +from transformers.utils import PaddingStrategy + + +@dataclass +class MultiModalChatDataCollator(DataCollatorMixin): + """ + Collator for multi-modal chat messages + """ + + tokenizer: PreTrainedTokenizerBase + processor: ProcessorMixin + return_tensors: str = "pt" + chat_template: Optional[str] = None + packing: bool = False + max_images: int = -1 + padding: Union[bool, str, PaddingStrategy] = True + pad_to_multiple_of: Optional[int] = None + + def __post_init__(self): + if self.packing: + raise ValueError("Packing is currently not supported.") + + def torch_call( + self, examples: List[Union[List[int], Any, Dict[str, Any]]] + ) -> Dict[str, Any]: + # Handle dict or lists with proper padding and conversion to tensor. + + return self.__class__.process_rows( + examples, self.processor, self.chat_template, self.max_images + ) + + @staticmethod + def process_rows(examples, processor, chat_template, max_images, length_only=False): + # HINT: use `_torch_collate_batch` to stack and pad tensors + # see also DataCollatorWithFlattening and DefaultDataCollator + + # *** This is COPIED from the trl example sft_vlm.py code *** + # use this as a starting point + + # Get the texts and images, and apply the chat template + texts = [ + processor.apply_chat_template( + example["messages"], chat_template=chat_template, tokenize=False + ) + for example in examples + ] + images = [example["images"] for example in examples] + + if max_images > 0: + images = [img_batch[:max_images] for img_batch in images] + + # Tokenize the texts and process the images + batch = processor(text=texts, images=images, return_tensors="pt", padding=True) + + # The labels are the input_ids, and we mask the padding tokens in the loss computation + labels = batch["input_ids"].clone() + labels[labels == processor.tokenizer.pad_token_id] = -100 # + # Ignore the image token index in the loss computation (model specific) + image_token_id = processor.tokenizer.convert_tokens_to_ids( + processor.image_token + ) + labels[labels == image_token_id] = -100 + batch["labels"] = labels + + if length_only: + return { + "length": [len(sample["input_ids"]) for sample in batch["input_ids"]] + } + return batch diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 82436e8d79..f732db06fc 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -121,15 +121,36 @@ def normalize_config(cfg): cfg.base_model_config = cfg.base_model model_config = load_model_config(cfg) - cfg.model_config_type = model_config.model_type cfg.tokenizer_config = ( cfg.tokenizer_config or cfg.base_model_config or cfg.base_model ) + cfg.is_multimodal = ( + hasattr(model_config, "model_type") + and model_config.model_type in ["llava", "mllama"] + or any( + multimodal_name in cfg.base_model.lower() + for multimodal_name in [ + "pixtral", + ] + ) + or cfg.is_multimodal + ) + if cfg.is_multimodal: + cfg.processor_config = ( + cfg.processor_config or cfg.base_model_config or cfg.base_model + ) + model_config = model_config.text_config + + cfg.model_config_type = model_config.model_type + # figure out if the model is llama cfg.is_llama_derived_model = ( - (hasattr(model_config, "model_type") and model_config.model_type == "llama") + ( + hasattr(model_config, "model_type") + and model_config.model_type == ["llama", "mllama_text_model"] + ) or cfg.is_llama_derived_model or "llama" in cfg.base_model.lower() or (cfg.type_of_model and "llama" in cfg.type_of_model.lower()) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 4e07c9260a..fced5e639d 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -188,6 +188,7 @@ class ChatTemplate(str, Enum): gemma = "gemma" # pylint: disable=invalid-name cohere = "cohere" # pylint: disable=invalid-name llama3 = "llama3" # pylint: disable=invalid-name + llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name phi_3 = "phi_3" # pylint: disable=invalid-name phi_35 = "phi_35" # pylint: disable=invalid-name deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name @@ -228,11 +229,12 @@ class LoraConfig(BaseModel): lora_r: Optional[int] = None lora_alpha: Optional[int] = None lora_fan_in_fan_out: Optional[bool] = None - lora_target_modules: Optional[List[str]] = None + lora_target_modules: Optional[Union[str, List[str]]] = None lora_target_linear: Optional[bool] = None lora_modules_to_save: Optional[List[str]] = None lora_dropout: Optional[float] = 0.0 peft_layers_to_transform: Optional[List[int]] = None + peft_layers_pattern: Optional[List[str]] = None peft: Optional[PeftConfig] = None peft_use_dora: Optional[bool] = None peft_use_rslora: Optional[bool] = None @@ -328,6 +330,9 @@ class ModelInputConfig(BaseModel): tokenizer_type: Optional[str] = Field( default=None, metadata={"help": "transformers tokenizer class"} ) + processor_type: Optional[str] = Field( + default=None, metadata={"help": "transformers processor class"} + ) trust_remote_code: Optional[bool] = None model_kwargs: Optional[Dict[str, Any]] = None @@ -530,6 +535,7 @@ class Config: dataset_prepared_path: Optional[str] = None dataset_shard_num: Optional[int] = None dataset_shard_idx: Optional[int] = None + skip_prepare_dataset: Optional[bool] = False pretraining_dataset: Optional[ # type: ignore conlist(Union[PretrainingDataset, SFTDataset], min_length=1) @@ -997,6 +1003,18 @@ def check_eval_packing(cls, data): return data + @model_validator(mode="before") + @classmethod + def check_mm_prepare(cls, data): + if data.get("skip_prepare_dataset"): + if data.get("remove_unused_columns") is None: + LOG.info( + "setting `remove_unused_columns: false` for skip_prepare_dataset" + ) + data["remove_unused_columns"] = False + + return data + @model_validator(mode="before") @classmethod def check_warmup(cls, data): @@ -1052,6 +1070,15 @@ def check_frozen(cls, data): return data + @model_validator(mode="before") + @classmethod + def check_peft_layers_pattern(cls, data): + if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"): + raise ValueError( + "peft_layers_pattern requires peft_layers_to_transform to be set" + ) + return data + @model_validator(mode="after") def check_fft_possible_bad_config(self): if ( diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 1b6df1cded..7d6922cbf2 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -51,20 +51,31 @@ LOG = logging.getLogger("axolotl") -def prepare_dataset(cfg, tokenizer): +def prepare_dataset(cfg, tokenizer, processor=None): prompters = [] if not cfg.pretraining_dataset: with zero_first(is_local_main_process()): if cfg.test_datasets: train_dataset, _, prompters = load_prepare_datasets( - tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train" + tokenizer, + cfg, + DEFAULT_DATASET_PREPARED_PATH, + split="train", + processor=processor, ) _, eval_dataset, _ = load_prepare_datasets( - tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test" + tokenizer, + cfg, + DEFAULT_DATASET_PREPARED_PATH, + split="test", + processor=processor, ) else: train_dataset, eval_dataset, prompters = load_prepare_datasets( - tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH + tokenizer, + cfg, + DEFAULT_DATASET_PREPARED_PATH, + processor=processor, ) else: path = cfg.pretraining_dataset @@ -123,6 +134,7 @@ def load_tokenized_prepared_datasets( cfg, default_dataset_prepared_path, split="train", + processor=None, ) -> Tuple[DatasetDict, List[Prompter]]: cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets tokenizer_name = cfg.tokenizer_config @@ -180,6 +192,7 @@ def load_tokenized_prepared_datasets( cfg.dataset_prepared_path and any(prepared_ds_path.glob("*")) and not cfg.is_preprocess + and not cfg.skip_prepare_dataset ): LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") dataset = load_from_disk(str(prepared_ds_path)) @@ -423,12 +436,16 @@ def for_d_in_datasets(dataset_configs): dataset=ds, d_base_type=d_base_type, d_prompt_style=d_prompt_style, + processor=processor, ) datasets.append(dataset_wrapper) prompters.append(dataset_prompter) - LOG.info("merging datasets") - dataset = concatenate_datasets(datasets) + if len(datasets) == 1: + dataset = datasets[0] + else: + LOG.info("merging datasets") + dataset = concatenate_datasets(datasets) if len(datasets) > 1: if cfg.shuffle_merged_datasets: @@ -437,9 +454,10 @@ def for_d_in_datasets(dataset_configs): else: LOG.debug("NOT shuffling merged datasets") - dataset, _ = process_datasets_for_packing(cfg, dataset, None) + if not cfg.skip_prepare_dataset: + dataset, _ = process_datasets_for_packing(cfg, dataset, None) - if cfg.local_rank == 0: + if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") dataset.save_to_disk(str(prepared_ds_path)) if cfg.push_dataset_to_hub: @@ -478,9 +496,14 @@ def load_prepare_datasets( cfg, default_dataset_prepared_path, split="train", + processor=None, ) -> Tuple[Dataset, Dataset, List[Prompter]]: dataset, prompters = load_tokenized_prepared_datasets( - tokenizer, cfg, default_dataset_prepared_path, split=split + tokenizer, + cfg, + default_dataset_prepared_path, + split=split, + processor=processor, ) if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: @@ -546,6 +569,7 @@ def get_dataset_wrapper( d_base_type, dataset, d_prompt_style=None, + processor=None, ): dataset_wrapper = None dataset_prompter = None @@ -578,7 +602,11 @@ def get_dataset_wrapper( dataset, **ds_kwargs, ) - elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset): + elif cfg.skip_prepare_dataset: + dataset_wrapper = dataset + elif ds_strategy := load( + config_dataset.type, tokenizer, cfg, config_dataset, processor=processor + ): dataset_prompter = UnsupportedPrompter() dataset_wrapper = TokenizedPromptDataset( ds_strategy, diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index e183301991..c18af9760f 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -28,12 +28,17 @@ AddedToken, AutoConfig, AutoModelForCausalLM, + AutoModelForVision2Seq, + AutoProcessor, AutoTokenizer, AwqConfig, BitsAndBytesConfig, GPTQConfig, + LlavaForConditionalGeneration, + MllamaForConditionalGeneration, PreTrainedModel, PreTrainedTokenizerBase, + ProcessorMixin, ) from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled @@ -80,6 +85,9 @@ def get_module_class_from_name(module, name): def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]): + if cfg.is_multimodal: + model_config = model_config.text_config + quant_config_exists = ( hasattr(model_config, "quantization_config") and model_config.quantization_config @@ -299,11 +307,31 @@ def load_tokenizer(cfg): return tokenizer +def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): + processor_kwargs: Dict[str, Any] = {} # do we actually need this? + + processor_cls = AutoProcessor + if cfg.processor_type: + processor_cls = getattr(transformers, cfg.processor_type) + + processor = processor_cls.from_pretrained( + cfg.processor_config, + trust_remote_code=cfg.trust_remote_code or False, + tokenizer=tokenizer, + **processor_kwargs, + ) + + return processor + + def load_model( cfg: DictDefault, tokenizer: PreTrainedTokenizerBase, + *, + processor: ProcessorMixin = None, # pylint: disable=unused-argument inference: bool = False, reference_model: bool = False, + **kwargs, # pylint: disable=unused-argument ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: """ Load a model for a given configuration and tokenizer. @@ -319,12 +347,23 @@ def load_model( plugin_manager = PluginManager.get_instance() plugin_manager.pre_model_load(cfg) + if cfg.is_multimodal: + text_model_config = model_config.text_config + else: + text_model_config = model_config + # TODO refactor as a kwarg load_in_8bit = cfg.load_in_8bit if cfg.gradient_checkpointing == "unsloth": transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper + if hasattr(model_config, "model_type") and model_config.model_type == "mllama": + if cfg.flash_attention: + from axolotl.monkeypatch.attention.mllama import patch_mllama + + patch_mllama() + if hasattr(model_config, "model_type") and model_config.model_type == "btlm": if cfg.flash_attention: from axolotl.monkeypatch.btlm_attn_hijack_flash import ( @@ -461,6 +500,19 @@ def load_model( max_memory = cfg.max_memory device_map = cfg.device_map + AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name + if cfg.is_multimodal: + if model_config.model_type == "llava": + AutoModelLoader = ( # pylint: disable=invalid-name + LlavaForConditionalGeneration + ) + elif model_config.model_type == "mllama": + AutoModelLoader = ( # pylint: disable=invalid-name + MllamaForConditionalGeneration + ) + else: + AutoModelLoader = AutoModelForVision2Seq # pylint: disable=invalid-name + if cfg.gpu_memory_limit: gpu_memory_limit = ( str(cfg.gpu_memory_limit) + "GiB" @@ -478,7 +530,7 @@ def load_model( from accelerate import infer_auto_device_map with init_empty_weights(): - model_canvas = AutoModelForCausalLM.from_config( + model_canvas = AutoModelLoader.from_config( model_config, trust_remote_code=cfg.trust_remote_code or False ) model_canvas.tie_weights() @@ -633,6 +685,8 @@ def load_model( quantization_config = ( quantization_config or model_kwargs["quantization_config"] ) + if cfg.is_multimodal: + model_config.text_config = text_model_config model = load_sharded_model_quant( base_model, model_config, @@ -651,7 +705,9 @@ def load_model( if "device_map" in model_kwargs: del model_kwargs["device_map"] - model = AutoModelForCausalLM.from_pretrained( + if cfg.is_multimodal: + model_config.text_config = text_model_config + model = AutoModelLoader.from_pretrained( base_model, config=model_config, **model_kwargs, @@ -690,13 +746,17 @@ def load_model( and not cfg.trust_remote_code ): if cfg.gptq: - model = AutoModelForCausalLM.from_pretrained( + if cfg.is_multimodal: + model_config.text_config = text_model_config + model = AutoModelLoader.from_pretrained( base_model, config=model_config, trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) else: + if cfg.is_multimodal: + model_config.text_config = text_model_config model = getattr(transformers, model_type).from_pretrained( base_model, config=model_config, @@ -707,21 +767,23 @@ def load_model( # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this # when training starts if ( - hasattr(model_config, "max_seq_len") - and model_config.max_seq_len + hasattr(text_model_config, "max_seq_len") + and text_model_config.max_seq_len and cfg.sequence_len > model_config.max_seq_len ): - model_config.max_seq_len = cfg.sequence_len + text_model_config.max_seq_len = cfg.sequence_len LOG.warning(f"increasing context length to {cfg.sequence_len}") elif ( - hasattr(model_config, "max_sequence_length") - and model_config.max_sequence_length - and cfg.sequence_len > model_config.max_sequence_length + hasattr(text_model_config, "max_sequence_length") + and text_model_config.max_sequence_length + and cfg.sequence_len > text_model_config.max_sequence_length ): - model_config.max_sequence_length = cfg.sequence_len + text_model_config.max_sequence_length = cfg.sequence_len LOG.warning(f"increasing context length to {cfg.sequence_len}") if cfg.gptq: - model = AutoModelForCausalLM.from_pretrained( + if cfg.is_multimodal: + model_config.text_config = text_model_config + model = AutoModelLoader.from_pretrained( base_model, config=model_config, trust_remote_code=cfg.trust_remote_code or False, @@ -734,7 +796,9 @@ def load_model( if "device_map" in model_kwargs: del model_kwargs["device_map"] - model = AutoModelForCausalLM.from_pretrained( + if cfg.is_multimodal: + model_config.text_config = text_model_config + model = AutoModelLoader.from_pretrained( base_model, config=model_config, trust_remote_code=cfg.trust_remote_code or False, @@ -1016,12 +1080,17 @@ def load_lora(model, cfg, inference=False, config_only=False): from peft import LoraConfig, get_peft_model - lora_target_modules = list(cfg.lora_target_modules or []) + lora_target_modules = cfg.lora_target_modules or [] if cfg.lora_target_linear: linear_names = find_all_linear_names(model) LOG.info(f"found linear modules: {repr(sorted(linear_names))}") - lora_target_modules = list(set(lora_target_modules + linear_names)) + lora_target_modules_as_list = ( + lora_target_modules + if isinstance(lora_target_modules, list) + else [lora_target_modules] + ) + lora_target_modules = list(set(lora_target_modules_as_list + linear_names)) lora_config_kwargs = {} loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits @@ -1040,6 +1109,7 @@ def load_lora(model, cfg, inference=False, config_only=False): lora_alpha=cfg.lora_alpha, target_modules=lora_target_modules, layers_to_transform=cfg.peft_layers_to_transform, + layers_pattern=cfg.peft_layers_pattern, lora_dropout=cfg.lora_dropout, fan_in_fan_out=cfg.lora_fan_in_fan_out, modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None, diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 89ae4e6970..17276dd8ed 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -306,7 +306,7 @@ def process_pretraining_datasets_for_packing( def calculate_total_num_steps(cfg, train_dataset, update=True): - if not cfg.total_num_tokens: + if not cfg.total_num_tokens and not cfg.skip_prepare_dataset: total_num_tokens = np.sum( train_dataset.data.column("input_ids") .to_pandas() @@ -319,7 +319,11 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): skip_estimates = cfg.model_config_type == "mamba" - if not skip_estimates and not cfg.total_supervised_tokens: + if ( + not skip_estimates + and not cfg.total_supervised_tokens + and not cfg.skip_prepare_dataset + ): total_supervised_tokens = ( train_dataset.data.column("labels") .to_pandas() @@ -478,13 +482,15 @@ def prepare_opinionated_env(cfg): os.environ["TOKENIZERS_PARALLELISM"] = "false" -def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): +def setup_trainer( + cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps +): if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]: - trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) + trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor) trainer_builder.model_ref = model[1] trainer_builder.peft_config = model[2] else: - trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer) + trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor) trainer_builder.train_dataset = train_dataset trainer_builder.eval_dataset = eval_dataset diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index 28210b7ae8..20533504ce 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -73,7 +73,7 @@ def test_llama3(self, llama3_tokenizer, assistant_dataset): strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_templates("llama3"), + chat_template=chat_templates("llama3"), message_field_role="role", message_field_content="content", roles={ @@ -113,7 +113,7 @@ def test_phi35(self, phi35_tokenizer, assistant_dataset): strategy = ChatTemplateStrategy( ChatTemplatePrompter( phi35_tokenizer, - chat_templates("phi_35"), + chat_template=chat_templates("phi_35"), message_field_role="role", message_field_content="content", roles={ @@ -171,7 +171,7 @@ def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset): strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_templates("llama3"), + chat_template=chat_templates("llama3"), message_field_role="role", message_field_content="content", message_field_training="training", @@ -227,8 +227,11 @@ class TestSharegptChatTemplateLlama3: def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset): LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts") + # pylint: disable=duplicate-code strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, train_on_eos="none", @@ -277,8 +280,11 @@ def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset): def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset): LOG.info("Testing ShareGPT style datasets with llama-3 human prompts") + # pylint: disable=duplicate-code strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, train_on_eos="none", @@ -327,8 +333,11 @@ def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset): def test_llama3_system_human(self, llama3_tokenizer, basic_dataset): LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts") + # pylint: disable=duplicate-code strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, train_on_eos="none", diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py index f18fb39423..50429e3a26 100644 --- a/tests/prompt_strategies/test_chat_templates_advanced.py +++ b/tests/prompt_strategies/test_chat_templates_advanced.py @@ -34,7 +34,9 @@ def find_sublist(full_list, sub_list): def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_inputs=True") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=True, sequence_len=512, @@ -77,7 +79,9 @@ def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset): def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_inputs=False") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -118,7 +122,9 @@ def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset): def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset): LOG.info("Testing roles_to_train with assistant only") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -144,7 +150,9 @@ def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset): def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset): LOG.info("Testing roles_to_train with all roles") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=True, sequence_len=512, @@ -175,7 +183,9 @@ def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset): def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with empty roles_to_train") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -194,7 +204,9 @@ def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset): def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='all'") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -219,7 +231,9 @@ def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset): def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='turn'") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -267,7 +281,9 @@ def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset): def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='last'") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -298,7 +314,9 @@ def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset): def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='none'") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter( + llama3_tokenizer, chat_template=chat_templates("llama3") + ), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -324,7 +342,9 @@ def test_drop_system_message(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with drop_system_message=True") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_templates("llama3"), drop_system_message=True + llama3_tokenizer, + chat_template=chat_templates("llama3"), + drop_system_message=True, ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -350,7 +370,9 @@ def test_custom_roles(self, llama3_tokenizer): } strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_templates("llama3"), roles=custom_roles + llama3_tokenizer, + chat_template=chat_templates("llama3"), + roles=custom_roles, ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -402,7 +424,7 @@ def test_message_field_training(self, llama3_tokenizer): strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_templates("llama3"), + chat_template=chat_templates("llama3"), message_field_training="train", message_field_training_detail="train_detail", ), From 4ca0a47cfb884f2d4421785982e64220b26f48df Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 9 Oct 2024 08:43:11 -0400 Subject: [PATCH 53/89] add 2.4.1 to base models (#1953) --- .github/workflows/base.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 9101fc2bea..5e8c8fc33d 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -30,6 +30,12 @@ jobs: python_version: "3.11" pytorch: 2.4.0 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" + - cuda: "124" + cuda_version: 12.4.1 + cudnn_version: "" + python_version: "3.11" + pytorch: 2.4.1 + torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" steps: - name: Checkout uses: actions/checkout@v3 From e8d3da00814ec7773d33edd5643bb885d85686cb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 9 Oct 2024 11:53:56 -0400 Subject: [PATCH 54/89] upgrade pytorch from 2.4.0 => 2.4.1 (#1950) * upgrade pytorch from 2.4.0 => 2.4.1 * update xformers for updated pytorch version * handle xformers version case for torch==2.3.1 --- .github/workflows/base.yml | 2 +- .github/workflows/main.yml | 4 ++-- .github/workflows/nightlies.yml | 4 ++-- .github/workflows/tests-nightly.yml | 4 ++-- .github/workflows/tests.yml | 4 ++-- requirements.txt | 2 +- setup.py | 7 +++++++ 7 files changed, 17 insertions(+), 10 deletions(-) diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 5e8c8fc33d..1b24f2c970 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -28,7 +28,7 @@ jobs: cuda_version: 12.4.1 cudnn_version: "" python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" - cuda: "124" cuda_version: 12.4.1 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 5a972f5f08..c27dbedefa 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -27,7 +27,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 axolotl_extras: runs-on: axolotl-gpu-runner steps: @@ -84,7 +84,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 axolotl_extras: runs-on: axolotl-gpu-runner steps: diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml index 1d95a0983f..17c76c24e7 100644 --- a/.github/workflows/nightlies.yml +++ b/.github/workflows/nightlies.yml @@ -26,7 +26,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 axolotl_extras: runs-on: axolotl-gpu-runner steps: @@ -83,7 +83,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 axolotl_extras: runs-on: axolotl-gpu-runner steps: diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml index 30ed397cef..8c9e1f49e7 100644 --- a/.github/workflows/tests-nightly.yml +++ b/.github/workflows/tests-nightly.yml @@ -25,7 +25,7 @@ jobs: fail-fast: false matrix: python_version: ["3.10", "3.11"] - pytorch_version: ["2.3.1", "2.4.0"] + pytorch_version: ["2.3.1", "2.4.1"] timeout-minutes: 20 steps: @@ -91,7 +91,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 num_gpus: 1 axolotl_extras: nightly_build: "true" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c104e92c27..a798bdd5cd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,7 +36,7 @@ jobs: fail-fast: false matrix: python_version: ["3.10", "3.11"] - pytorch_version: ["2.3.1", "2.4.0"] + pytorch_version: ["2.3.1", "2.4.1"] timeout-minutes: 20 steps: @@ -94,7 +94,7 @@ jobs: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.4.0 + pytorch: 2.4.1 num_gpus: 1 axolotl_extras: steps: diff --git a/requirements.txt b/requirements.txt index 123a4ee54a..41bfdfbeb4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ flash-attn==2.6.3 sentencepiece wandb einops -xformers==0.0.27 +xformers==0.0.28.post1 optimum==1.16.2 hf_transfer colorama diff --git a/setup.py b/setup.py index 1b64fadaef..e939bc37ee 100644 --- a/setup.py +++ b/setup.py @@ -49,10 +49,17 @@ def parse_requirements(): else: raise ValueError("Invalid version format") + if (major, minor) >= (2, 4): + if patch == 0: + _install_requires.pop(_install_requires.index(xformers_version)) + _install_requires.append("xformers>=0.0.27") if (major, minor) >= (2, 3): if patch == 0: _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.26.post1") + else: + _install_requires.pop(_install_requires.index(xformers_version)) + _install_requires.append("xformers>=0.0.27") elif (major, minor) >= (2, 2): _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.25.post1") From a560593b1dbac3f3afcbe6bdf975c9c9e5a5afcc Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 10 Oct 2024 03:02:32 +0700 Subject: [PATCH 55/89] fix(log): update perplexity log to clarify from eval split (#1952) [skip ci] --- src/axolotl/utils/callbacks/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 73715b06ab..acc2238a4f 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -462,7 +462,7 @@ def evaluate_preds(sources, predictions, references): references=[[r] for r in references], predictions=predictions, ) - scores[metric_name] = score + scores["eval_" + metric_name] = score return scores def predict_with_generate(): From dee77232feb5c7e41216e5586da3ec4407638846 Mon Sep 17 00:00:00 2001 From: aarush gupta Date: Wed, 9 Oct 2024 13:03:16 -0700 Subject: [PATCH 56/89] fix type annotations (#1941) [skip ci] --- src/axolotl/monkeypatch/relora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py index e4352cbe3d..9d246cb17f 100644 --- a/src/axolotl/monkeypatch/relora.py +++ b/src/axolotl/monkeypatch/relora.py @@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio): def reset_optimizer( optimizer: torch.optim.Optimizer, *, - reset_params: list[str], # where str is the key to a torch.nn.Parameter - optimizer_state_keys: list[str], + reset_params: List[str], # where str is the key to a torch.nn.Parameter + optimizer_state_keys: List[str], prune_ratio: float = 0.9, ): pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio) From 6d3caadf90a9d4faafe8e167441355d128c66537 Mon Sep 17 00:00:00 2001 From: Boris Feld Date: Wed, 9 Oct 2024 22:03:37 +0200 Subject: [PATCH 57/89] Comet integration (#1939) * Add first version of a Comet integration * Remove debug prints * Add test for Comet Configuration transformation to env variables * Fix last lint warning * Update Readme for Comet logging documentation * Update Comet integration to be optional, update code and tests * Add documentation for Comet configuration * Add missing check --- .isort.cfg | 2 +- README.md | 18 ++- docs/config.qmd | 12 ++ src/axolotl/cli/__init__.py | 3 + src/axolotl/core/trainer_builder.py | 15 ++- src/axolotl/utils/__init__.py | 6 +- src/axolotl/utils/callbacks/__init__.py | 11 +- src/axolotl/utils/callbacks/comet_.py | 43 ++++++++ src/axolotl/utils/comet_.py | 93 ++++++++++++++++ .../config/models/input/v0_4_1/__init__.py | 14 +++ tests/test_validation.py | 103 ++++++++++++++++++ 11 files changed, 315 insertions(+), 5 deletions(-) create mode 100644 src/axolotl/utils/callbacks/comet_.py create mode 100644 src/axolotl/utils/comet_.py diff --git a/.isort.cfg b/.isort.cfg index 79067a7c91..e487797321 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,3 +1,3 @@ [settings] profile=black -known_third_party=wandb +known_third_party=wandb,comet_ml diff --git a/README.md b/README.md index c84f1cb8c9..f6f4e4e806 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Features: - Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking - Works with single GPU or multiple GPUs via FSDP or Deepspeed - Easily run with Docker locally or on the cloud -- Log results and optionally checkpoints to wandb or mlflow +- Log results and optionally checkpoints to wandb, mlflow or Comet - And more! @@ -515,6 +515,22 @@ wandb_name: wandb_log_model: ``` +##### Comet Logging + +Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`. + +- wandb options +```yaml +use_comet: +comet_api_key: +comet_workspace: +comet_project_name: +comet_experiment_key: +comet_mode: +comet_online: +comet_experiment_config: +``` + ##### Special Tokens It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this: diff --git a/docs/config.qmd b/docs/config.qmd index e859999787..99a69a0973 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -267,6 +267,18 @@ mlflow_tracking_uri: # URI to mlflow mlflow_experiment_name: # Your experiment name hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry +# Comet configuration if you're using it +# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`. +# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start +use_comet: # Enable or disable Comet integration. +comet_api_key: # API key for Comet. Recommended to set via `comet login`. +comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace. +comet_project_name: # Project name in Comet. Defaults to Uncategorized. +comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key. +comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration. +comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True. +comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details. + # Where to save the full-finetuned model to output_dir: ./completed-model diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index a1d84b6a16..db975501a3 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -31,6 +31,7 @@ from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.comet_ import setup_comet_env_vars from axolotl.utils.config import ( normalize_cfg_datasets, normalize_config, @@ -421,6 +422,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): setup_mlflow_env_vars(cfg) + setup_comet_env_vars(cfg) + return cfg diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 4893e63dc2..b1ee519dc4 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -48,7 +48,7 @@ from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler -from axolotl.utils import is_mlflow_available +from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.callbacks import ( EvalFirstStepCallback, GPUStatsCallback, @@ -1111,6 +1111,12 @@ def get_callbacks(self) -> List[TrainerCallback]: callbacks.append( SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path) ) + if self.cfg.use_comet and is_comet_available(): + from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback + + callbacks.append( + SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path) + ) return callbacks @@ -1179,6 +1185,11 @@ def get_post_trainer_create_callbacks(self, trainer): trainer, self.tokenizer, "mlflow" ) callbacks.append(LogPredictionCallback(self.cfg)) + if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0: + LogPredictionCallback = log_prediction_callback_factory( + trainer, self.tokenizer, "comet_ml" + ) + callbacks.append(LogPredictionCallback(self.cfg)) if self.cfg.do_bench_eval: callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer)) @@ -1430,6 +1441,8 @@ def build(self, total_num_steps): report_to.append("mlflow") if self.cfg.use_tensorboard: report_to.append("tensorboard") + if self.cfg.use_comet: + report_to.append("comet_ml") training_arguments_kwargs["report_to"] = report_to training_arguments_kwargs["run_name"] = ( diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py index 99dec79f1b..91545009ad 100644 --- a/src/axolotl/utils/__init__.py +++ b/src/axolotl/utils/__init__.py @@ -1,8 +1,12 @@ """ Basic utils for Axolotl """ -import importlib +import importlib.util def is_mlflow_available(): return importlib.util.find_spec("mlflow") is not None + + +def is_comet_available(): + return importlib.util.find_spec("comet_ml") is not None diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index acc2238a4f..0bc781fcb4 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -29,7 +29,7 @@ ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy -from axolotl.utils import is_mlflow_available +from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.callbacks.perplexity import Perplexity from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig @@ -747,6 +747,15 @@ def log_table_from_dataloader(name: str, table_dataloader): artifact_file="PredictionsVsGroundTruth.json", tracking_uri=tracking_uri, ) + elif logger == "comet_ml" and is_comet_available(): + import comet_ml + + experiment = comet_ml.get_running_experiment() + if experiment: + experiment.log_table( + f"{name} - Predictions vs Ground Truth.csv", + pd.DataFrame(table_data), + ) if is_main_process(): log_table_from_dataloader("Eval", eval_dataloader) diff --git a/src/axolotl/utils/callbacks/comet_.py b/src/axolotl/utils/callbacks/comet_.py new file mode 100644 index 0000000000..b29f997a86 --- /dev/null +++ b/src/axolotl/utils/callbacks/comet_.py @@ -0,0 +1,43 @@ +"""Comet module for trainer callbacks""" + +import logging +from typing import TYPE_CHECKING + +import comet_ml +from transformers import TrainerCallback, TrainerControl, TrainerState + +from axolotl.utils.distributed import is_main_process + +if TYPE_CHECKING: + from axolotl.core.trainer_builder import AxolotlTrainingArguments + +LOG = logging.getLogger("axolotl.callbacks") + + +class SaveAxolotlConfigtoCometCallback(TrainerCallback): + """Callback to save axolotl config to comet""" + + def __init__(self, axolotl_config_path): + self.axolotl_config_path = axolotl_config_path + + def on_train_begin( + self, + args: "AxolotlTrainingArguments", # pylint: disable=unused-argument + state: TrainerState, # pylint: disable=unused-argument + control: TrainerControl, + **kwargs, # pylint: disable=unused-argument + ): + if is_main_process(): + try: + comet_experiment = comet_ml.start(source="axolotl") + comet_experiment.log_other("Created from", "axolotl") + comet_experiment.log_asset( + self.axolotl_config_path, + file_name="axolotl-config", + ) + LOG.info( + "The Axolotl config has been saved to the Comet Experiment under assets." + ) + except (FileNotFoundError, ConnectionError) as err: + LOG.warning(f"Error while saving Axolotl config to Comet: {err}") + return control diff --git a/src/axolotl/utils/comet_.py b/src/axolotl/utils/comet_.py new file mode 100644 index 0000000000..b4ecc80ad9 --- /dev/null +++ b/src/axolotl/utils/comet_.py @@ -0,0 +1,93 @@ +"""Module for wandb utilities""" + +import logging +import os + +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger("axolotl.utils.comet_") + +COMET_ENV_MAPPING_OVERRIDE = { + "comet_mode": "COMET_START_MODE", + "comet_online": "COMET_START_ONLINE", +} +COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = { + "auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS", + "auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE", + "auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS", + "auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD", + "auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS", + "auto_log_co2": "COMET_AUTO_LOG_CO2", + "auto_metric_logging": "COMET_AUTO_LOG_METRICS", + "auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE", + "auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER", + "auto_param_logging": "COMET_AUTO_LOG_PARAMETERS", + "comet_disabled": "COMET_AUTO_LOG_DISABLE", + "display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL", + "distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER", + "log_code": "COMET_AUTO_LOG_CODE", + "log_env_cpu": "COMET_AUTO_LOG_ENV_CPU", + "log_env_details": "COMET_AUTO_LOG_ENV_DETAILS", + "log_env_disk": "COMET_AUTO_LOG_ENV_DISK", + "log_env_gpu": "COMET_AUTO_LOG_ENV_GPU", + "log_env_host": "COMET_AUTO_LOG_ENV_HOST", + "log_env_network": "COMET_AUTO_LOG_ENV_NETWORK", + "log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA", + "log_git_patch": "COMET_AUTO_LOG_GIT_PATCH", + "log_graph": "COMET_AUTO_LOG_GRAPH", + "name": "COMET_START_EXPERIMENT_NAME", + "offline_directory": "COMET_OFFLINE_DIRECTORY", + "parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS", + "tags": "COMET_START_EXPERIMENT_TAGS", +} + + +def python_value_to_environ_value(python_value): + if isinstance(python_value, bool): + if python_value is True: + return "true" + + return "false" + + if isinstance(python_value, int): + return str(python_value) + + if isinstance(python_value, list): # Comet only have one list of string parameter + return ",".join(map(str, python_value)) + + return python_value + + +def setup_comet_env_vars(cfg: DictDefault): + # TODO, we need to convert Axolotl configuration to environment variables + # as Transformers integration are call first and would create an + # Experiment first + + for key in cfg.keys(): + if key.startswith("comet_") and key != "comet_experiment_config": + value = cfg.get(key, "") + + if value is not None and value != "": + env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper()) + final_value = python_value_to_environ_value(value) + os.environ[env_variable_name] = final_value + + if cfg.comet_experiment_config: + for key, value in cfg.comet_experiment_config.items(): + if value is not None and value != "": + config_env_variable_name = ( + COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key) + ) + + if config_env_variable_name is None: + LOG.warning( + f"Unknown Comet Experiment Config name {key}, ignoring it" + ) + continue + + final_value = python_value_to_environ_value(value) + os.environ[config_env_variable_name] = final_value + + # Enable comet if project name is present + if cfg.comet_project_name and len(cfg.comet_project_name) > 0: + cfg.use_comet = True diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index fced5e639d..76748191bf 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -489,6 +489,19 @@ def check_wandb_run(cls, data): return data +class CometConfig(BaseModel): + """Comet configuration subset""" + + use_comet: Optional[bool] = None + comet_api_key: Optional[str] = None + comet_workspace: Optional[str] = None + comet_project_name: Optional[str] = None + comet_experiment_key: Optional[str] = None + comet_mode: Optional[str] = None + comet_online: Optional[bool] = None + comet_experiment_config: Optional[Dict[str, Any]] = None + + class GradioConfig(BaseModel): """Gradio configuration subset""" @@ -509,6 +522,7 @@ class AxolotlInputConfig( HyperparametersConfig, WandbConfig, MLFlowConfig, + CometConfig, LISAConfig, GradioConfig, RemappedParameters, diff --git a/tests/test_validation.py b/tests/test_validation.py index 35d0e265e7..6e0d0ad2a5 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -9,6 +9,7 @@ import pytest from pydantic import ValidationError +from axolotl.utils import is_comet_available from axolotl.utils.config import validate_config from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities from axolotl.utils.dict import DictDefault @@ -1329,3 +1330,105 @@ def test_wandb_set_disabled(self, minimal_cfg): os.environ.pop("WANDB_PROJECT", None) os.environ.pop("WANDB_DISABLED", None) + + +@pytest.mark.skipif(is_comet_available() is False, reason="comet_ml is not installed") +class TestValidationComet(BaseValidation): + """ + Validation test for comet + """ + + def test_comet_sets_env(self, minimal_cfg): + from axolotl.utils.comet_ import setup_comet_env_vars + + comet_config = { + "comet_api_key": "foo", + "comet_workspace": "some_workspace", + "comet_project_name": "some_project", + "comet_experiment_key": "some_experiment_key", + "comet_mode": "get_or_create", + "comet_online": False, + "comet_experiment_config": { + "auto_histogram_activation_logging": False, + "auto_histogram_epoch_rate": 2, + "auto_histogram_gradient_logging": True, + "auto_histogram_tensorboard_logging": False, + "auto_histogram_weight_logging": True, + "auto_log_co2": False, + "auto_metric_logging": True, + "auto_metric_step_rate": 15, + "auto_output_logging": False, + "auto_param_logging": True, + "comet_disabled": False, + "display_summary_level": 2, + "distributed_node_identifier": "some_distributed_node_identifier", + "log_code": True, + "log_env_cpu": False, + "log_env_details": True, + "log_env_disk": False, + "log_env_gpu": True, + "log_env_host": False, + "log_env_network": True, + "log_git_metadata": False, + "log_git_patch": True, + "log_graph": False, + "name": "some_name", + "offline_directory": "some_offline_directory", + "parse_args": True, + "tags": ["tag1", "tag2"], + }, + } + + cfg = DictDefault(comet_config) | minimal_cfg + + new_cfg = validate_config(cfg) + + setup_comet_env_vars(new_cfg) + + comet_env = { + key: value for key, value in os.environ.items() if key.startswith("COMET_") + } + + assert ( + len(comet_env) + == len(comet_config) + len(comet_config["comet_experiment_config"]) - 1 + ) + + assert comet_env == { + "COMET_API_KEY": "foo", + "COMET_AUTO_LOG_CLI_ARGUMENTS": "true", + "COMET_AUTO_LOG_CO2": "false", + "COMET_AUTO_LOG_CODE": "true", + "COMET_AUTO_LOG_DISABLE": "false", + "COMET_AUTO_LOG_ENV_CPU": "false", + "COMET_AUTO_LOG_ENV_DETAILS": "true", + "COMET_AUTO_LOG_ENV_DISK": "false", + "COMET_AUTO_LOG_ENV_GPU": "true", + "COMET_AUTO_LOG_ENV_HOST": "false", + "COMET_AUTO_LOG_ENV_NETWORK": "true", + "COMET_AUTO_LOG_GIT_METADATA": "false", + "COMET_AUTO_LOG_GIT_PATCH": "true", + "COMET_AUTO_LOG_GRAPH": "false", + "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS": "false", + "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE": "2", + "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS": "true", + "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD": "false", + "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS": "true", + "COMET_AUTO_LOG_METRIC_STEP_RATE": "15", + "COMET_AUTO_LOG_METRICS": "true", + "COMET_AUTO_LOG_OUTPUT_LOGGER": "false", + "COMET_AUTO_LOG_PARAMETERS": "true", + "COMET_DISPLAY_SUMMARY_LEVEL": "2", + "COMET_DISTRIBUTED_NODE_IDENTIFIER": "some_distributed_node_identifier", + "COMET_EXPERIMENT_KEY": "some_experiment_key", + "COMET_OFFLINE_DIRECTORY": "some_offline_directory", + "COMET_PROJECT_NAME": "some_project", + "COMET_START_EXPERIMENT_NAME": "some_name", + "COMET_START_EXPERIMENT_TAGS": "tag1,tag2", + "COMET_START_MODE": "get_or_create", + "COMET_START_ONLINE": "false", + "COMET_WORKSPACE": "some_workspace", + } + + for key in comet_env.keys(): + os.environ.pop(key, None) From 979534c851ddd4bdd53a8d0162b3ee349774860c Mon Sep 17 00:00:00 2001 From: pandora <128635000+pandora-s-git@users.noreply.github.com> Date: Thu, 10 Oct 2024 15:22:53 +0200 Subject: [PATCH 58/89] add mistral templates (#1927) Co-authored-by: Wing Lian --- src/axolotl/utils/chat_templates.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 7468ae8b15..9e1e6ca326 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -5,7 +5,9 @@ CHAT_TEMPLATES = { "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", - "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. + "mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1... + "mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3: Mistral 7B V3, Small, Large... + "mistral_v3_tekken": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3-Tekken: Nemo, Pixtral... "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", From 8159cbd1ab5f0cf3dd5d07237cb7d0d3e40b8a57 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 10 Oct 2024 15:04:17 -0400 Subject: [PATCH 59/89] lm_eval harness post train (#1926) * wip, lm_eval harness post train * include latex parser * add dtype and doc * add validation when doing bench evals * automatically add test dataset when doing benches --- requirements.txt | 6 +++ src/axolotl/cli/train.py | 15 ++++--- src/axolotl/integrations/base.py | 37 ++++++++++++++++ src/axolotl/integrations/lm_eval/README.md | 13 ++++++ src/axolotl/integrations/lm_eval/__init__.py | 42 +++++++++++++++++++ src/axolotl/integrations/lm_eval/args.py | 15 +++++++ .../config/models/input/v0_4_1/__init__.py | 20 +++++++++ 7 files changed, 143 insertions(+), 5 deletions(-) create mode 100644 src/axolotl/integrations/lm_eval/README.md create mode 100644 src/axolotl/integrations/lm_eval/__init__.py create mode 100644 src/axolotl/integrations/lm_eval/args.py diff --git a/requirements.txt b/requirements.txt index 41bfdfbeb4..4323c76ce1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,3 +46,9 @@ gcsfs>=2024.5.0 trl==0.9.6 zstandard==0.22.0 fastcore + +# lm eval harness +lm_eval==0.4.4 +langdetect==1.0.9 +immutabledict==4.2.0 +antlr4-python3-runtime==4.13.2 diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 050f18a054..16d66a82f0 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -3,13 +3,11 @@ """ import logging from pathlib import Path -from typing import Tuple, Union +from typing import Union import fire from dotenv import load_dotenv from transformers.hf_argparser import HfArgumentParser -from transformers.modeling_utils import PreTrainedModel -from transformers.tokenization_utils import PreTrainedTokenizer from axolotl.cli import ( check_accelerate_default_config, @@ -20,6 +18,7 @@ print_axolotl_text_art, ) from axolotl.common.cli import TrainerCliArgs +from axolotl.integrations.base import PluginManager from axolotl.prompt_strategies.sharegpt import ( register_chatml_template, register_llama3_template, @@ -39,7 +38,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): return do_train(parsed_cfg, parsed_cli_args) -def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: +def do_train(cfg, cli_args) -> None: print_axolotl_text_art() check_accelerate_default_config() check_user_token() @@ -64,7 +63,13 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + plugin_manager = PluginManager.get_instance() + + del model + del tokenizer + + plugin_manager.post_train_unload(cfg) if __name__ == "__main__": diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index d26eed90fe..e2bd79bc4d 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -159,6 +159,29 @@ def add_callbacks_post_trainer(self, cfg, trainer): List[callable]: A list of callback functions to be added to the TrainingArgs """ + def post_train(self, cfg, model): + """ + Performs actions after training is complete. + + Parameters: + cfg (dict): The axolotl configuration + model (object): The loaded model. + + Returns: + None + """ + + def post_train_unload(self, cfg): + """ + Performs actions after training is complete and the model is unloaded. + + Parameters: + cfg (dict): The configuration for the plugin. + + Returns: + None + """ + def load_plugin(plugin_name: str) -> BasePlugin: """ @@ -381,3 +404,17 @@ def add_callbacks_post_trainer(self, cfg, trainer): for plugin in self.plugins: callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer)) return callbacks + + def post_train_unload(self, cfg): + """ + Calls the post_train_unload method of all registered plugins. + + Parameters: + cfg (dict): The configuration for the plugins. + model (object): The loaded model. + + Returns: + None + """ + for plugin in self.plugins: + plugin.post_train_unload(cfg) diff --git a/src/axolotl/integrations/lm_eval/README.md b/src/axolotl/integrations/lm_eval/README.md new file mode 100644 index 0000000000..3724c49ccf --- /dev/null +++ b/src/axolotl/integrations/lm_eval/README.md @@ -0,0 +1,13 @@ +# LM Eval Harness + +### Usage + +```yaml +plugins: + - axolotl.integrations.lm_eval.LMEvalPlugin + +lm_eval_tasks: + - gsm8k + - hellaswag + - arc_easy +``` diff --git a/src/axolotl/integrations/lm_eval/__init__.py b/src/axolotl/integrations/lm_eval/__init__.py new file mode 100644 index 0000000000..f1daa20000 --- /dev/null +++ b/src/axolotl/integrations/lm_eval/__init__.py @@ -0,0 +1,42 @@ +""" +Module for the Plugin for LM Eval Harness +""" +import subprocess # nosec +from datetime import datetime + +from axolotl.integrations.base import BasePlugin + +from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401 + + +class LMEvalPlugin(BasePlugin): + """ + Plugin for LM Evaluation Harness integraton with Axolotl. + """ + + def get_input_args(self): + return "axolotl.integrations.lm_eval.LMEvalArgs" + + def post_train_unload(self, cfg): + tasks = ",".join(cfg.lm_eval_tasks) + fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else "" + dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16" + output_path = cfg.output_dir + output_path += "" if cfg.output_dir.endswith("/") else "/" + output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S") + subprocess.run( # nosec + [ + "lm_eval", + "--model", + "hf", + "--model_args", + f"pretrained={cfg.output_dir}{fa2}{dtype}", + "--tasks", + tasks, + "--batch_size", + str(cfg.lm_eval_batch_size), + "--output_path", + output_path, + ], + check=True, + ) diff --git a/src/axolotl/integrations/lm_eval/args.py b/src/axolotl/integrations/lm_eval/args.py new file mode 100644 index 0000000000..f58e6a6e38 --- /dev/null +++ b/src/axolotl/integrations/lm_eval/args.py @@ -0,0 +1,15 @@ +""" +Module for handling lm eval harness input arguments. +""" +from typing import List, Optional + +from pydantic import BaseModel + + +class LMEvalArgs(BaseModel): + """ + Input args for lm eval harness + """ + + lm_eval_tasks: List[str] = [] + lm_eval_batch_size: Optional[int] = 8 diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 76748191bf..47796add6b 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -980,6 +980,26 @@ def check_evals(cls, data): "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch." ) + if data.get("do_bench_eval") and not ( + data.get("evals_per_epoch") or data.get("eval_steps") + ): + raise ValueError( + "do_bench_eval requires evals_per_epoch or eval_steps to be set." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_test_datasets_bench(cls, data): + if ( + data.get("do_bench_eval") + and not data.get("test_datasets") + and not data.get("val_set_size") + ): + LOG.warning( + "`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset." + ) + data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}] return data @model_validator(mode="before") From 2fbc6b0c644424feff91abe3d001fa1c6638e118 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 10 Oct 2024 15:57:37 -0400 Subject: [PATCH 60/89] Axo logo new (#1956) * update axolotl ascii art * spacing for logo * cleanup dithering * cleanup ascii logo a bit --- src/axolotl/cli/__init__.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index db975501a3..fd5ab3e56c 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -55,8 +55,22 @@ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" - -def print_axolotl_text_art(suffix=None): +AXOLOTL_LOGO = """ + #@@ #@@ @@# @@# + @@ @@ @@ @@ =@@# @@ #@ =@@#. + @@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@ + #@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@ + @@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@ + @@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@ + @@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@ + =@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@ + @@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@ + =@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@ + @@@@ @@@@@@@@@@@@@@@@ +""" + + +def print_legacy_axolotl_text_art(suffix=None): font = "nancyj" ascii_text = " axolotl" if suffix: @@ -69,6 +83,13 @@ def print_axolotl_text_art(suffix=None): print_dep_versions() +def print_axolotl_text_art( + **kwargs, # pylint: disable=unused-argument +): + if is_main_process(): + print(AXOLOTL_LOGO) + + def print_dep_versions(): packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"] max_len = max(len(pkg) for pkg in packages) From e73b8dff8d5fcfb02371916cbebc1350a3a1a9c9 Mon Sep 17 00:00:00 2001 From: Thomas Cleberg <84520378+thomascleberg@users.noreply.github.com> Date: Fri, 11 Oct 2024 12:32:50 -0500 Subject: [PATCH 61/89] Add Support for `revision` Dataset Parameter to specify reading from Huggingface Dataset Revision (#1912) * Add support for `revision` dataset parameter * only use revision on hf hub backed datasets * use revision tied to head * set download to use revision * feat: add config to model validator class * feat: add revision config to RL and tests for it --------- Co-authored-by: Wing Lian Co-authored-by: NanoCode012 --- docs/config.qmd | 1 + .../config/models/input/v0_4_1/__init__.py | 3 + src/axolotl/utils/data/rl.py | 1 + src/axolotl/utils/data/sft.py | 6 +- tests/test_datasets.py | 138 ++++++++++++++++++ 5 files changed, 148 insertions(+), 1 deletion(-) diff --git a/docs/config.qmd b/docs/config.qmd index 99a69a0973..8329f35535 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -90,6 +90,7 @@ datasets: shards: # Optional[int] number of shards to split data into name: # Optional[str] name of dataset configuration to load train_on_split: train # Optional[str] name of dataset split to load from + revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets. # Optional[str] fastchat conversation type, only used with type: sharegpt conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 47796add6b..1c33b59078 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -125,6 +125,7 @@ class SFTDataset(BaseModel): drop_system_message: Optional[bool] = None trust_remote_code: Optional[bool] = False + revision: Optional[str] = None class UserDefinedDPOType(BaseModel): @@ -146,6 +147,7 @@ class DPODataset(BaseModel): split: Optional[str] = None type: Optional[Union[UserDefinedDPOType, str]] = None data_files: Optional[List[str]] = None + revision: Optional[str] = None class UserDefinedKTOType(BaseModel): @@ -167,6 +169,7 @@ class KTODataset(BaseModel): type: Optional[Union[UserDefinedKTOType, str]] = None data_files: Optional[List[str]] = None trust_remote_code: Optional[bool] = False + revision: Optional[str] = None class RLType(str, Enum): diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index d0324e1ebd..35bd5fcbb7 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -90,6 +90,7 @@ def load_split(dataset_cfgs, _cfg): ds = load_dataset( # pylint: disable=invalid-name ds_cfg["path"], split=ds_cfg["split"], + revision=ds_cfg.get("revision", None), ) split_datasets.insert(i, ds) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 7d6922cbf2..39eb2c4e04 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -242,6 +242,7 @@ def for_d_in_datasets(dataset_configs): name=config_dataset.name, streaming=True, token=use_auth_token, + revision=config_dataset.revision, ) ds_from_hub = True except (FileNotFoundError, ConnectionError, HFValidationError, ValueError): @@ -346,6 +347,7 @@ def for_d_in_datasets(dataset_configs): streaming=False, data_files=config_dataset.data_files, token=use_auth_token, + revision=config_dataset.revision, **load_ds_kwargs, ) elif ds_from_cloud and remote_file_system: @@ -380,6 +382,7 @@ def for_d_in_datasets(dataset_configs): repo_id=config_dataset.path, repo_type="dataset", filename=config_dataset.data_files, + revision=config_dataset.revision, ) elif isinstance(config_dataset.data_files, list): fp = [] @@ -389,6 +392,7 @@ def for_d_in_datasets(dataset_configs): repo_id=config_dataset.path, repo_type="dataset", filename=file, + revision=config_dataset.revision, ) ) else: @@ -433,8 +437,8 @@ def for_d_in_datasets(dataset_configs): config_dataset=config_dataset, tokenizer=tokenizer, cfg=cfg, - dataset=ds, d_base_type=d_base_type, + dataset=ds, d_prompt_style=d_prompt_style, processor=processor, ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index a274b7b894..f8b463a03e 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -12,6 +12,7 @@ from transformers import AutoTokenizer from axolotl.utils.data import load_tokenized_prepared_datasets +from axolotl.utils.data.rl import load_prepare_dpo_datasets from axolotl.utils.dict import DictDefault @@ -267,6 +268,143 @@ def test_load_from_single_json(self): assert "attention_mask" in dataset.features assert "labels" in dataset.features + def test_load_hub_with_dpo(self): + """Verify that processing dpo data from the hub works""" + + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "rl": "dpo", + "chat_template": "llama3", + "datasets": [ + { + "path": "fozziethebeat/alpaca_messages_2k_dpo_test", + "type": "chat_template.default", + "chat_template": "llama3", + "field_messages": "conversation", + "field_chosen": "chosen", + "field_rejected": "rejected", + "message_field_role": "role", + "message_field_content": "content", + "roles": { + "system": ["system"], + "user": ["user"], + "assistant": ["assistant"], + }, + } + ], + } + ) + + train_dataset, _ = load_prepare_dpo_datasets(cfg) + + assert len(train_dataset) == 1800 + assert "conversation" in train_dataset.features + + def test_load_hub_with_revision(self): + """Verify that processing data from the hub works with a specific revision""" + with tempfile.TemporaryDirectory() as tmp_dir: + prepared_path = Path(tmp_dir) / "prepared" + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + "revision": "d05c1cb", + }, + ], + } + ) + + dataset, _ = load_tokenized_prepared_datasets( + self.tokenizer, cfg, prepared_path + ) + + assert len(dataset) == 2000 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + + def test_load_hub_with_revision_with_dpo(self): + """Verify that processing dpo data from the hub works with a specific revision""" + + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "rl": "dpo", + "chat_template": "llama3", + "datasets": [ + { + "path": "fozziethebeat/alpaca_messages_2k_dpo_test", + "type": "chat_template.default", + "chat_template": "llama3", + "revision": "ea82cff", + "field_messages": "conversation", + "field_chosen": "chosen", + "field_rejected": "rejected", + "message_field_role": "role", + "message_field_content": "content", + "roles": { + "system": ["system"], + "user": ["user"], + "assistant": ["assistant"], + }, + } + ], + } + ) + + train_dataset, _ = load_prepare_dpo_datasets(cfg) + + assert len(train_dataset) == 1800 + assert "conversation" in train_dataset.features + + def test_load_local_hub_with_revision(self): + """Verify that a local copy of a hub dataset can be loaded with a specific revision""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_ds_path = Path("mhenrichsen/alpaca_2k_test") + tmp_ds_path.mkdir(parents=True, exist_ok=True) + snapshot_download( + repo_id="mhenrichsen/alpaca_2k_test", + repo_type="dataset", + local_dir=tmp_ds_path, + revision="d05c1cb", + ) + + prepared_path = Path(tmp_dir) / "prepared" + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "ds_type": "parquet", + "type": "alpaca", + "data_files": [ + "mhenrichsen/alpaca_2k_test/alpaca_2000.parquet", + ], + "revision": "d05c1cb", + }, + ], + } + ) + + dataset, _ = load_tokenized_prepared_datasets( + self.tokenizer, cfg, prepared_path + ) + + assert len(dataset) == 2000 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + shutil.rmtree(tmp_ds_path) + if __name__ == "__main__": unittest.main() From 922db77521f37d32ba7a5ab72b56904fed3bcb5c Mon Sep 17 00:00:00 2001 From: Adam Hazell <34248583+awhazell@users.noreply.github.com> Date: Fri, 11 Oct 2024 18:33:06 +0100 Subject: [PATCH 62/89] Add MLFlow run name option in config (#1961) Co-authored-by: Adam Hazell --- docs/config.qmd | 1 + src/axolotl/core/trainer_builder.py | 9 ++++++--- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 1 + 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/config.qmd b/docs/config.qmd index 8329f35535..b6c0cb852a 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -266,6 +266,7 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step # mlflow configuration if you're using it mlflow_tracking_uri: # URI to mlflow mlflow_experiment_name: # Your experiment name +mlflow_run_name: # Your run name hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry # Comet configuration if you're using it diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index b1ee519dc4..9c12b6141a 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1445,9 +1445,12 @@ def build(self, total_num_steps): report_to.append("comet_ml") training_arguments_kwargs["report_to"] = report_to - training_arguments_kwargs["run_name"] = ( - self.cfg.wandb_name if self.cfg.use_wandb else None - ) + if self.cfg.use_wandb: + training_arguments_kwargs["run_name"] = self.cfg.wandb_name + elif self.cfg.use_mlflow: + training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name + else: + training_arguments_kwargs["run_name"] = None training_arguments_kwargs["optim"] = ( self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" ) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 1c33b59078..1a269b7982 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -447,6 +447,7 @@ class MLFlowConfig(BaseModel): use_mlflow: Optional[bool] = None mlflow_tracking_uri: Optional[str] = None mlflow_experiment_name: Optional[str] = None + mlflow_run_name: Optional[str] = None hf_mlflow_log_artifacts: Optional[bool] = None From 76883851d233d3734c19b1979ede7020059ea37d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 11 Oct 2024 13:33:20 -0400 Subject: [PATCH 63/89] add warning that sharegpt will be deprecated (#1957) * add warning that sharegpt will be deprecated * add helper script for chat_templates and document deprecation * Update src/axolotl/prompt_strategies/sharegpt.py Co-authored-by: NanoCode012 --------- Co-authored-by: NanoCode012 --- README.md | 2 +- scripts/chat_datasets.py | 60 +++++++++++++++++++++++ src/axolotl/prompt_strategies/sharegpt.py | 3 ++ 3 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 scripts/chat_datasets.py diff --git a/README.md b/README.md index f6f4e4e806..4ce7a351bb 100644 --- a/README.md +++ b/README.md @@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod - typescript type: ... # unimplemented custom format - # fastchat conversation + # fastchat conversation (deprecation soon, use chat_template) # See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py - path: ... type: sharegpt diff --git a/scripts/chat_datasets.py b/scripts/chat_datasets.py new file mode 100644 index 0000000000..5eb5bde1e2 --- /dev/null +++ b/scripts/chat_datasets.py @@ -0,0 +1,60 @@ +""" +helper script to parse chat datasets into a usable yaml +""" +import click +import yaml +from datasets import load_dataset + + +@click.command() +@click.argument("dataset", type=str) +@click.option("--split", type=str, default="train") +def parse_dataset(dataset=None, split="train"): + ds_cfg = {} + ds_cfg["path"] = dataset + ds_cfg["split"] = split + ds_cfg["type"] = "chat_template" + ds_cfg["chat_template"] = "<<>>" + + dataset = load_dataset(dataset, split=split) + features = dataset.features + feature_keys = features.keys() + field_messages = None + for key in ["conversation", "conversations", "messages"]: + if key in feature_keys: + field_messages = key + break + if not field_messages: + raise ValueError( + f'No conversation field found in dataset: {", ".join(feature_keys)}' + ) + ds_cfg["field_messages"] = field_messages + + message_fields = features["conversations"][0].keys() + message_field_role = None + for key in ["from", "role"]: + if key in message_fields: + message_field_role = key + break + if not message_field_role: + raise ValueError( + f'No role field found in messages: {", ".join(message_fields)}' + ) + ds_cfg["message_field_role"] = message_field_role + + message_field_content = None + for key in ["content", "text", "value"]: + if key in message_fields: + message_field_content = key + break + if not message_field_content: + raise ValueError( + f'No content field found in messages: {", ".join(message_fields)}' + ) + ds_cfg["message_field_content"] = message_field_content + + print(yaml.dump({"datasets": [ds_cfg]})) + + +if __name__ == "__main__": + parse_dataset() diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index 321f19554b..4565c35d5d 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -61,6 +61,9 @@ def build_loader( default_conversation: Optional[str] = None, ): def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): + LOG.warning( + "sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead.", + ) conversation = ( ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg From df359c8a6e14ecdd2e1eb0049bd8143c32421952 Mon Sep 17 00:00:00 2001 From: Afrizal Hasbi Azizy Date: Sat, 12 Oct 2024 00:34:13 +0700 Subject: [PATCH 64/89] Handle image input as string paths for MMLMs (#1958) * Update mm_chat.py Handle string image (paths) * chore: lint --------- Co-authored-by: Wing Lian --- src/axolotl/utils/collators/mm_chat.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py index f49e97f37f..b9b67f8750 100644 --- a/src/axolotl/utils/collators/mm_chat.py +++ b/src/axolotl/utils/collators/mm_chat.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union +from PIL import Image from transformers import PreTrainedTokenizerBase, ProcessorMixin from transformers.data.data_collator import DataCollatorMixin from transformers.utils import PaddingStrategy @@ -52,7 +53,12 @@ def process_rows(examples, processor, chat_template, max_images, length_only=Fal ) for example in examples ] - images = [example["images"] for example in examples] + images = [ + Image.open(example["images"]) + if isinstance(example["images"], str) + else example["images"] + for example in examples + ] if max_images > 0: images = [img_batch[:max_images] for img_batch in images] From 09bf1ceacc67b46d6bc5abb8cef2b47c9dd84b8c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 12 Oct 2024 18:19:48 -0400 Subject: [PATCH 65/89] update hf deps (#1964) * update hf deps * remove deprecated set_caching_enabled --- requirements.txt | 12 ++++++------ src/axolotl/utils/trainer.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/requirements.txt b/requirements.txt index 4323c76ce1..2dd3517a7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 -peft==0.13.0 -transformers==4.45.1 -tokenizers>=0.19.1 -bitsandbytes==0.44.0 -accelerate==0.34.2 -datasets==2.21.0 +peft==0.13.2 +transformers==4.45.2 +tokenizers>=0.20.1 +bitsandbytes==0.44.1 +accelerate==1.0.0 +datasets==3.0.1 deepspeed==0.14.4 pydantic==2.6.3 addict diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 17276dd8ed..30b40925f9 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -11,7 +11,7 @@ import torch import torch.cuda from accelerate.logging import get_logger -from datasets import set_caching_enabled +from datasets import disable_caching, enable_caching from torch.utils.data import DataLoader, RandomSampler from transformers.utils import is_torch_bf16_gpu_available @@ -87,10 +87,10 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True): @contextmanager def disable_datasets_caching(): try: - set_caching_enabled(False) + disable_caching() yield finally: - set_caching_enabled(True) + enable_caching() def add_position_ids(sample): From d20b48a61e8dff5565303166fde5303c811e5491 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 12 Oct 2024 20:53:48 -0400 Subject: [PATCH 66/89] only install torchao for torch versions >= 2.4.0 (#1963) --- requirements.txt | 2 ++ setup.py | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2dd3517a7a..37ee1e42cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -52,3 +52,5 @@ lm_eval==0.4.4 langdetect==1.0.9 immutabledict==4.2.0 antlr4-python3-runtime==4.13.2 + +torchao==0.5.0 diff --git a/setup.py b/setup.py index e939bc37ee..7d9568dbff 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ def parse_requirements(): try: xformers_version = [req for req in _install_requires if "xformers" in req][0] + torchao_version = [req for req in _install_requires if "torchao" in req][0] if "Darwin" in platform.system(): # don't install xformers on MacOS _install_requires.pop(_install_requires.index(xformers_version)) @@ -53,7 +54,8 @@ def parse_requirements(): if patch == 0: _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.27") - if (major, minor) >= (2, 3): + elif (major, minor) >= (2, 3): + _install_requires.pop(_install_requires.index(torchao_version)) if patch == 0: _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.26.post1") @@ -61,9 +63,11 @@ def parse_requirements(): _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.27") elif (major, minor) >= (2, 2): + _install_requires.pop(_install_requires.index(torchao_version)) _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.25.post1") else: + _install_requires.pop(_install_requires.index(torchao_version)) _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.23.post1") From 31591bd94cf8fd3e18fc8949385cf405b1ff0dda Mon Sep 17 00:00:00 2001 From: pandora <128635000+pandora-s-git@users.noreply.github.com> Date: Sun, 13 Oct 2024 03:40:39 +0200 Subject: [PATCH 67/89] Fixing Validation - Mistral Templates (#1962) --- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 1a269b7982..af1570db63 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -187,7 +187,9 @@ class ChatTemplate(str, Enum): alpaca = "alpaca" # pylint: disable=invalid-name chatml = "chatml" # pylint: disable=invalid-name - inst = "inst" # pylint: disable=invalid-name + mistral_v1 = "mistral_v1" # pylint: disable=invalid-name + mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name + mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name gemma = "gemma" # pylint: disable=invalid-name cohere = "cohere" # pylint: disable=invalid-name llama3 = "llama3" # pylint: disable=invalid-name From ac128b7b1dde6e6f0ca9a06697cad6fa31c9d5b0 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 13 Oct 2024 08:41:13 +0700 Subject: [PATCH 68/89] fix: update eval causal lm metrics to add perplexity (#1951) [skip ci] --- docs/config.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/config.qmd b/docs/config.qmd index b6c0cb852a..703d587753 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -315,7 +315,7 @@ max_steps: eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 -eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf] +eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"] loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training) loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3) From 1834cdc3645c003e3db02346912cab19a1eb5ca3 Mon Sep 17 00:00:00 2001 From: Vincent Haines Date: Sat, 12 Oct 2024 21:41:43 -0400 Subject: [PATCH 69/89] Add support for qwen 2.5 chat template (#1934) --- src/axolotl/utils/chat_templates.py | 1 + src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 9e1e6ca326..2443f56f93 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -17,6 +17,7 @@ "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', + "qwen_25": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n", } diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index af1570db63..40f4a36abb 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -198,6 +198,7 @@ class ChatTemplate(str, Enum): phi_35 = "phi_35" # pylint: disable=invalid-name deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name jamba = "jamba" # pylint: disable=invalid-name + qwen_25 = "qwen_25" # pylint: disable=invalid-name class LoftQConfig(BaseModel): From cd2d89f4672e90af45c6a632c75a624b6219c712 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 13 Oct 2024 12:15:18 -0400 Subject: [PATCH 70/89] wip add new proposed message structure (#1904) * wip add new proposed message structure * tokenization * wip * wip transform builder * wip make the chat dataset loadable * wip chatml + llama 3 new chat objects * chore: lint * chore: lint * fix tokenization * remove dacite dependency since we're using pydantic now * fix handling when already correctly split in messages * make sure to remove chat features from tokenized ds * move chat to be a input transform for messages * make sure llama3 has the bos token * remove non-working special token code * fix messages strat loader --- requirements_env.txt | 315 ++++++++++++++++++ src/axolotl/cli/preprocess.py | 10 +- src/axolotl/core/chat/__init__.py | 0 src/axolotl/core/chat/format/__init__.py | 0 src/axolotl/core/chat/format/chatml.py | 34 ++ src/axolotl/core/chat/format/llama3x.py | 45 +++ src/axolotl/core/chat/format/shared.py | 47 +++ src/axolotl/core/chat/messages.py | 230 +++++++++++++ src/axolotl/core/datasets/__init__.py | 0 src/axolotl/core/datasets/chat.py | 55 +++ .../core/datasets/transforms/__init__.py | 0 .../core/datasets/transforms/chat_builder.py | 150 +++++++++ src/axolotl/prompt_strategies/__init__.py | 7 +- .../prompt_strategies/messages/__init__.py | 34 ++ .../prompt_strategies/messages/chat.py | 84 +++++ src/axolotl/prompt_tokenizers.py | 6 + .../config/models/input/v0_4_1/__init__.py | 2 + src/axolotl/utils/data/sft.py | 22 +- tests/core/chat/__init__.py | 0 tests/core/chat/format/__init__.py | 0 tests/core/chat/test_messages.py | 197 +++++++++++ tests/prompt_strategies/messages/__init__.py | 0 tests/prompt_strategies/messages/test_chat.py | 62 ++++ 23 files changed, 1285 insertions(+), 15 deletions(-) create mode 100644 requirements_env.txt create mode 100644 src/axolotl/core/chat/__init__.py create mode 100644 src/axolotl/core/chat/format/__init__.py create mode 100644 src/axolotl/core/chat/format/chatml.py create mode 100644 src/axolotl/core/chat/format/llama3x.py create mode 100644 src/axolotl/core/chat/format/shared.py create mode 100644 src/axolotl/core/chat/messages.py create mode 100644 src/axolotl/core/datasets/__init__.py create mode 100644 src/axolotl/core/datasets/chat.py create mode 100644 src/axolotl/core/datasets/transforms/__init__.py create mode 100644 src/axolotl/core/datasets/transforms/chat_builder.py create mode 100644 src/axolotl/prompt_strategies/messages/__init__.py create mode 100644 src/axolotl/prompt_strategies/messages/chat.py create mode 100644 tests/core/chat/__init__.py create mode 100644 tests/core/chat/format/__init__.py create mode 100644 tests/core/chat/test_messages.py create mode 100644 tests/prompt_strategies/messages/__init__.py create mode 100644 tests/prompt_strategies/messages/test_chat.py diff --git a/requirements_env.txt b/requirements_env.txt new file mode 100644 index 0000000000..f8acbf73c2 --- /dev/null +++ b/requirements_env.txt @@ -0,0 +1,315 @@ +accelerate==0.34.1 +addict==2.4.0 +aiofiles==23.2.1 +aiohttp==3.9.0 +aiosignal==1.3.1 +aiostream==0.5.2 +alembic==1.13.1 +annotated-types==0.6.0 +annoy==1.17.3 +ansible==6.7.0 +ansible-core==2.13.13 +ansible-vault==2.1.0 +anyio==3.7.1 +appdirs==1.4.4 +art==6.0 +asgiref==3.7.2 +async-timeout==4.0.2 +attrdict==2.0.1 +attrs==22.2.0 +awscli==1.32.75 +-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl +backoff==2.2.1 +base58==2.1.1 +beartype==0.17.2 +bitnet==0.2.1 +bitsandbytes==0.42.0 +bittensor==6.7.0 +black==23.7.0 +blinker==1.7.0 +boto3==1.34.75 +botocore==1.34.75 +cachetools==5.3.3 +cachy==0.1.1 +certifi==2023.7.22 +cffi==1.16.0 +cfgv==3.3.1 +chai-guanaco==1.2.4 +charset-normalizer==3.2.0 +cleo==0.6.8 +click==8.1.7 +cloudpickle==2.0.0 +cohere==4.11.2 +colorama==0.4.4 +coloredlogs==15.0.1 +CoLT5-attention==0.10.20 +contextlib2==21.6.0 +contourpy==1.2.0 +cryptography==41.0.3 +cycler==0.12.1 +cytoolz==0.12.3 +databricks-cli==0.18.0 +dataclasses-json==0.5.7 +datasets==2.11.0 +ddt==1.6.0 +decorator==5.1.1 +deepspeed==0.15.0 +# Editable Git install with no remote (dialogpt==0.1) +-e /Users/wing/Projects/ml/dialogpt/src +dill==0.3.6 +distlib==0.3.6 +docker==7.0.0 +docker-pycreds==0.4.0 +docstring-parser==0.15 +docutils==0.16 +ecdsa==0.18.0 +einops==0.7.0 +einops-exts==0.0.4 +einx==0.1.3 +entrypoints==0.4 +eth-hash==0.6.0 +eth-keys==0.5.0 +eth-typing==4.0.0 +eth-utils==2.3.1 +evaluate==0.4.0 +exceptiongroup==1.1.1 +fastapi==0.109.2 +fastcore==1.5.29 +ffmpy==0.4.0 +filelock==3.12.2 +-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet +fire==0.5.0 +first==2.0.2 +flake8==7.0.0 +Flask==3.0.1 +fonttools==4.47.2 +frozendict==2.4.1 +frozenlist==1.3.3 +fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe +fsspec==2023.6.0 +fuzzywuzzy==0.18.0 +gitdb==4.0.10 +GitPython==3.1.31 +google-pasta==0.2.0 +gradio==4.42.0 +gradio_client==1.3.0 +greenlet==2.0.2 +grpclib==0.4.7 +gunicorn==21.2.0 +h11==0.14.0 +h2==4.1.0 +hpack==4.0.0 +httpcore==0.17.3 +httpx==0.24.1 +huggingface-hub==0.23.4 +humanfriendly==10.0 +hyperframe==6.0.1 +identify==2.5.24 +idna==3.4 +immutables==0.20 +importlib-metadata==6.7.0 +importlib-resources==6.1.1 +inflection==0.5.1 +iniconfig==2.0.0 +itsdangerous==2.1.2 +Jinja2==3.1.2 +jmespath==1.0.1 +joblib==1.3.2 +jsonlines==3.1.0 +jsonschema==2.6.0 +kiwisolver==1.4.5 +langchain==0.0.144 +Levenshtein==0.24.0 +libcst==1.1.0 +liger-kernel==0.0.0 +lion-pytorch==0.1.2 +llama-cpp-python==0.1.36 +llvmlite==0.40.1 +local-attention==1.9.0 +loguru==0.7.0 +Mako==1.3.2 +Markdown==3.5.2 +markdown-it-py==3.0.0 +markdown2==2.4.10 +MarkupSafe==2.1.2 +marshmallow==3.19.0 +marshmallow-enum==1.5.1 +matplotlib==3.8.2 +mccabe==0.7.0 +mdurl==0.1.2 +MEGABYTE-pytorch==0.0.7 +-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit +mlflow==2.10.0 +modal==0.62.77 +more-itertools==10.2.0 +mpmath==1.2.1 +msgpack==1.0.7 +msgpack-numpy-opentensor==0.5.0 +multidict==6.0.4 +multiprocess==0.70.14 +munch==2.5.0 +mypy==1.3.0 +mypy-extensions==1.0.0 +nest-asyncio==1.6.0 +netaddr==0.10.1 +networkx==3.0rc1 +nh3==0.2.14 +nodeenv==1.8.0 +nomic==2.0.2 +numba==0.57.1 +numexpr==2.8.4 +numpy==1.24.4 +oauthlib==3.2.2 +openai==0.27.4 +openapi==1.1.0 +openapi-schema-pydantic==1.2.4 +optimum==1.8.6 +orjson==3.10.7 +packaging==23.1 +pandas==2.0.0 +parameterized==0.9.0 +password-strength==0.0.3.post2 +pastel==0.1.1 +pathos==0.3.0 +pathspec==0.11.1 +pathtools==0.1.2 +peft==0.11.1 +pendulum==3.0.0 +Pillow==9.5.0 +pip-tools==1.11.0 +platformdirs==3.2.0 +pluggy==1.4.0 +poetry==0.7.1 +pox==0.3.2 +ppft==1.7.6.6 +pre-commit==3.3.2 +prettytable==3.10.0 +prompt-toolkit==3.0.39 +protobuf==3.20.2 +protobuf3-to-dict==0.1.5 +psutil==5.9.5 +psycopg==3.1.18 +PuLP==2.8.0 +py==1.11.0 +py-bip39-bindings==0.1.11 +py-cpuinfo==9.0.0 +py-ed25519-zebra-bindings==1.0.1 +py-sr25519-bindings==0.2.0 +pyarrow==11.0.0 +pyasn1==0.6.0 +pycodestyle==2.11.1 +pycparser==2.21 +pycryptodome==3.20.0 +pydantic==2.5.3 +pydantic_core==2.14.6 +pydub==0.25.1 +pyfiglet==0.8.post1 +pyflakes==3.2.0 +Pygments==2.15.1 +PyJWT==2.8.0 +pylev==1.4.0 +PyNaCl==1.5.0 +pynvml==11.5.0 +pyparsing==2.4.7 +pyrsistent==0.14.11 +pytest==8.0.2 +pytest-asyncio==0.23.4 +python-dateutil==2.8.2 +python-dotenv==1.0.1 +python-Levenshtein==0.24.0 +python-multipart==0.0.9 +pytz==2023.3 +PyYAML==6.0.1 +querystring-parser==1.2.4 +rapidfuzz==3.6.1 +regex==2023.6.3 +requests==2.31.0 +requests-toolbelt==0.8.0 +resolvelib==0.8.1 +responses==0.18.0 +retry==0.9.2 +rich==13.7.0 +rsa==4.7.2 +ruff==0.6.3 +s3transfer==0.10.1 +safetensors==0.4.5 +sagemaker==2.148.0 +scalecodec==1.2.7 +schedulefree==1.2.1 +schema==0.7.5 +scikit-learn==1.4.0 +scipy==1.9.3 +seaborn==0.13.2 +semantic-version==2.10.0 +sentencepiece==0.2.0 +sentry-sdk==1.19.1 +setproctitle==1.3.2 +shellingham==1.5.4 +shortuuid==1.0.11 +shtab==1.6.5 +sigtools==4.0.1 +six==1.16.0 +skypilot==0.4.1 +smdebug-rulesconfig==1.0.1 +smmap==5.0.0 +sniffio==1.3.0 +SQLAlchemy==1.4.47 +sqlparse==0.4.4 +starlette==0.36.3 +substrate-interface==1.5.2 +svgwrite==1.4.3 +sympy==1.11.1 +synchronicity==0.6.7 +tabulate==0.9.0 +tblib==1.7.0 +tenacity==8.2.2 +tensor-parallel==2.0.0 +termcolor==2.2.0 +text2art==0.2.0 +threadpoolctl==3.2.0 +tiktoken==0.6.0 +time-machine==2.14.1 +timm==0.9.16 +tokenizers==0.19.1 +tokenmonster==1.1.12 +toml==0.9.6 +tomli==2.0.1 +tomlkit==0.12.0 +toolz==0.12.1 +torch==2.2.0 +torchdata==0.6.1 +torchdiffeq==0.2.3 +TorchFix==0.4.0 +torchtext==0.15.2 +torchvision==0.17.0 +tqdm==4.66.2 +transformers==4.44.2 +trl==0.9.6 +typer==0.12.5 +types-certifi==2021.10.8.3 +types-requests==2.31.0.20240125 +types-setuptools==69.0.0.20240125 +types-toml==0.10.8.7 +typing==3.7.4.3 +typing-inspect==0.8.0 +typing_extensions==4.9.0 +tyro==0.5.18 +tzdata==2023.3 +unique-names-generator==1.0.2 +urllib3==2.2.2 +uvicorn==0.22.0 +vector_quantize_pytorch==1.14.1 +virtualenv==20.23.0 +voyager==2.0.2 +wandb==0.16.2 +watchfiles==0.21.0 +wavedrom==2.0.3.post3 +wcwidth==0.2.6 +websocket-client==1.7.0 +websockets==12.0 +Werkzeug==3.0.1 +wonderwords==2.2.0 +xxhash==3.2.0 +yarl==1.8.2 +zetascale==2.2.7 +zipp==3.15.0 diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index e12462c000..aab29e2670 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -27,6 +27,7 @@ register_chatml_template, register_llama3_template, ) +from axolotl.utils.trainer import disable_datasets_caching LOG = logging.getLogger("axolotl.cli.preprocess") @@ -70,10 +71,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): LOG.warning(msg) parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH - if parsed_cfg.rl: # and parsed_cfg.rl != "orpo": - load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) - else: - load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + with disable_datasets_caching(): + if parsed_cfg.rl: # and parsed_cfg.rl != "orpo": + load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + else: + load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) if parsed_cli_args.download: model_name = parsed_cfg.base_model diff --git a/src/axolotl/core/chat/__init__.py b/src/axolotl/core/chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/axolotl/core/chat/format/__init__.py b/src/axolotl/core/chat/format/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/axolotl/core/chat/format/chatml.py b/src/axolotl/core/chat/format/chatml.py new file mode 100644 index 0000000000..315d101a86 --- /dev/null +++ b/src/axolotl/core/chat/format/chatml.py @@ -0,0 +1,34 @@ +""" +ChatML transformation functions for MessageContents +""" +from typing import Optional + +from ..messages import MessageContents, Messages +from .shared import wrap_tools + + +def format_message( + message: Messages, + message_index: Optional[int] = None, # pylint: disable=unused-argument +) -> Messages: + if message.is_chat_formatted: + return message + + # prepend the role prefix within a MessageContents to message.content + message.content.insert( + 0, + MessageContents( + type="text", + value=f"<|im_start|>{message.role}\n", + weight=0, + ), + ) + message.content.append( + MessageContents(type="text", value="<|im_end|>", weight=message.weight) + ) + message.content.append(MessageContents(type="text", value="\n", weight=0)) + + message = wrap_tools(message) + + message.is_chat_formatted = True + return message diff --git a/src/axolotl/core/chat/format/llama3x.py b/src/axolotl/core/chat/format/llama3x.py new file mode 100644 index 0000000000..17fa7aa8d4 --- /dev/null +++ b/src/axolotl/core/chat/format/llama3x.py @@ -0,0 +1,45 @@ +""" +Llama 3.x chat formatting functions for MessageContents +""" +from typing import Optional + +from ..messages import MessageContents, Messages +from .shared import wrap_tools + + +def format_message(message: Messages, message_index: Optional[int] = None) -> Messages: + if message.is_chat_formatted: + return message + + message_role = message.role + if message.role == "tool": + message_role = "ipython" + + # prepend the role prefix within a MessageContents to message.content + message.content.insert( + 0, + MessageContents( + type="text", + value=f"<|start_header_id|>{message_role}<|end_header_id|>\n\n", + weight=0, + ), + ) + + message.content.append( + MessageContents(type="text", value="<|eot_id|>", weight=message.weight) + ) + + message = wrap_tools(message) + + if message_index == 0: + message.content.insert( + 0, + MessageContents( + type="text", + value="<|begin_of_text|>", + weight=0, + ), + ) + + message.is_chat_formatted = True + return message diff --git a/src/axolotl/core/chat/format/shared.py b/src/axolotl/core/chat/format/shared.py new file mode 100644 index 0000000000..9efa2353db --- /dev/null +++ b/src/axolotl/core/chat/format/shared.py @@ -0,0 +1,47 @@ +""" +shared functions for format transforms +""" +from axolotl.core.chat.messages import MessageContents, Messages + + +def wrap_tools(message: Messages): + # loop over message.content by index to find tool calls, we need to wrap each with tags, + # so be wary of indexing issues when changing the list while iterating. + # iterate over the range in reverse order to avoid index shifting + for i in range(len(message.content) - 1, -1, -1): + if message.content[i].type == "tool_call": + # append a MessageContents text tag after + message.content.insert( + i + 1, + MessageContents( + type="text", value="\n", weight=message.weight + ), + ) + # make sure the actual tool call content ends with a newline + message.content[i].has_newline = True + # prepend a MessageContents text tag before + message.content.insert( + i, + MessageContents( + type="text", value="\n", weight=message.weight + ), + ) + elif message.content[i].type == "tool_response": + # append a MessageContents text tag after + message.content.insert( + i + 1, + MessageContents( + type="text", value="\n", weight=message.weight + ), + ) + # make sure the actual tool response content ends with a newline + message.content[i].has_newline = True + # prepend a MessageContents text tag before + message.content.insert( + i, + MessageContents( + type="text", value="\n", weight=message.weight + ), + ) + + return message diff --git a/src/axolotl/core/chat/messages.py b/src/axolotl/core/chat/messages.py new file mode 100644 index 0000000000..c879bf477b --- /dev/null +++ b/src/axolotl/core/chat/messages.py @@ -0,0 +1,230 @@ +""" +internal message representations of chat messages +""" +import json +from enum import Enum +from typing import Any, Callable, List, Optional, Union + +from pydantic import BaseModel +from transformers import PreTrainedTokenizer + + +class MessageRoles(str, Enum): + """ + Message roles for the system, user, assistant, and tools + """ + + system = "system" # pylint: disable=invalid-name + user = "user" # pylint: disable=invalid-name + assistant = "assistant" # pylint: disable=invalid-name + tool = "tool" # pylint: disable=invalid-name + ipython = ( # pylint: disable=invalid-name + # for responses from builtin tools + "ipython" + ) + + +class MessageContentTypes(str, Enum): + """ + Message content types for text, image, audio, tool calls, and tool responses + """ + + special_token = "special_token" # pylint: disable=invalid-name # nosec B105 + text = "text" # pylint: disable=invalid-name + image = "image" # pylint: disable=invalid-name + audio = "audio" # pylint: disable=invalid-name + tool_call = "tool_call" # pylint: disable=invalid-name # to differentiate regular responses from tool calls from the assistant + tool_response = "tool_response" # pylint: disable=invalid-name + + +class SpecialToken(str, Enum): + """ + Special tokens for beginning of string and end of string + """ + + bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105 + eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105 + + +class ToolCallFunction(BaseModel): + """ + Tool call function with name and arguments + """ + + name: str + arguments: dict[str, str] + + +class Tool(BaseModel): + """ + Tool with description, function, and parameters + """ + + description: str + function: ToolCallFunction + parameters: dict[str, str] # .properties + + +class ToolCallContents(BaseModel): + """ + Tool call contents with name, arguments, and optional id + """ + + name: str + arguments: dict[str, Union[str, int]] + id: Optional[str] = None # pylint: disable=invalid-name + + def __str__(self) -> str: + data = {"name": self.name, "arguments": self.arguments} + if self.id is not None: + data["id"] = self.id + return json.dumps(data) + + +class ToolResponseContents(BaseModel): + """ + Tool response contents with name, content, and optional id + """ + + name: str + content: Union[str, dict[str, Union[str, int, float]]] + id: Optional[str] = None # pylint: disable=invalid-name + + def __str__(self) -> str: + data = {"name": self.name, "content": self.content} + if self.id is not None: + data["id"] = self.id + return json.dumps(data) + + +class MessageContents(BaseModel): + """ + Message contents with type, value, metadata, weight, newline, and end of contents + """ + + type: Union[str, MessageContentTypes] + value: Union[str, ToolCallContents, ToolResponseContents, SpecialToken] + meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata + weight: Optional[Union[int, float]] = None + has_newline: bool = False + eoc: bool = False # end of contents + + def __str__(self) -> str: + str_val = str(self.value) + if self.has_newline and not str_val.endswith("\n"): + str_val += "\n" + return str_val + + +class Messages(BaseModel): + """ + Messages with role, content, metadata, weight, and chat formatting + """ + + role: Union[MessageRoles, str] # allows for arbitrary roles + content: List["MessageContents"] + meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata + weight: Optional[Union[int, float]] = None + is_chat_formatted: bool = False + + def __str__(self) -> str: + return "".join(str(c) for c in self.content) + + def tokenized( + self, tokenizer: PreTrainedTokenizer, ignore_index=-100 + ) -> dict[str, List[int]]: + # iterate over the contents, tokenizing the concatenated string values up to the current MessageContents + # returns a dictionary mapping w input_ids, attention_mask, and labels + input_ids: List[int] = [] + labels: List[int] = [] + pending_input_ids: List[int] = [] + pending_weight = self.weight + running_content = "" + for _, msg_content in enumerate(self.content): + # TODO also handle non-text content types + if msg_content.type in [ + MessageContentTypes.text.value, + MessageContentTypes.tool_call.value, + MessageContentTypes.tool_response.value, + ]: + running_content += str(msg_content) + tok_results = tokenizer(running_content, add_special_tokens=False) + tok_input_ids = tok_results["input_ids"] + if pending_input_ids: + new_pending_inputs = tok_input_ids[ + len(input_ids) : len(input_ids) + len(pending_input_ids) + ] + if new_pending_inputs != pending_input_ids: + # logging.warning("tokenization mismatch from concatenation.") + pending_input_ids = new_pending_inputs + input_ids.extend(pending_input_ids) + if pending_weight: + labels.extend(pending_input_ids) + else: + labels.extend([ignore_index] * len(pending_input_ids)) + pending_input_ids = tok_results["input_ids"][len(input_ids) :] + pending_weight = self.weight and msg_content.weight not in [0, 0.0] + input_ids.extend(pending_input_ids) + if pending_weight: + labels.extend(pending_input_ids) + else: + labels.extend([ignore_index] * len(pending_input_ids)) + attention_mask = [1] * len(input_ids) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } + + +class Chats(BaseModel): + """ + top level data structure for chat conversations + """ + + conversation: List[Messages] + + def __str__(self) -> str: + return "".join(str(c) for c in self.conversation) + + def tokenized( + self, tokenizer: Callable[[str], dict[str, List[int]]], ignore_index=-100 + ) -> dict[str, List[int]]: + input_ids = [] + attention_mask = [] + labels = [] + for msg in self.conversation: + msg_results = msg.tokenized(tokenizer, ignore_index) + input_ids.extend(msg_results["input_ids"]) + attention_mask.extend(msg_results["attention_mask"]) + labels.extend(msg_results["labels"]) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } + + +class ChatFormattedChats(Chats): + """ + Chat formatted chats with formatter and optional train on inputs + """ + + formatter: Callable # [[Union[dict, Chats]], Chats] + train_on_inputs: bool = False + + def model_post_init(self, __context): + for i, msg in enumerate(self.conversation): + self.conversation[i] = self.formatter(msg, message_index=i) + if self.train_on_inputs: + self.conversation[i].weight = 1 + + +class PreferenceChats(BaseModel): + """ + representation for preference data for chat + """ + + prompt: List[Messages] + chosen: Messages + rejected: Messages diff --git a/src/axolotl/core/datasets/__init__.py b/src/axolotl/core/datasets/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/axolotl/core/datasets/chat.py b/src/axolotl/core/datasets/chat.py new file mode 100644 index 0000000000..e74c247d2c --- /dev/null +++ b/src/axolotl/core/datasets/chat.py @@ -0,0 +1,55 @@ +""" +chat dataset module +""" +import os +from typing import Callable, Optional, Union + +from datasets import Dataset +from transformers import PreTrainedTokenizer + +from axolotl.core.chat.messages import ChatFormattedChats + + +class TokenizedChatDataset(Dataset): + """ + Tokenized chat dataset + """ + + def __init__( + self, + data: Dataset, + model_transform: Union[PreTrainedTokenizer, Callable], + *args, + message_transform: Optional[Callable] = None, + formatter=None, + process_count: Optional[int] = None, + keep_in_memory: Optional[bool] = False, + **kwargs, + ): + def map_fn(ex): + if message_transform is not None: + ex = message_transform(ex) + if formatter is not None: + ex = ChatFormattedChats( + formatter=formatter, + **ex, + ) + else: + ex = ChatFormattedChats( + **ex, + ) + return ex.tokenized(model_transform) + + process_or_cpu_count: int = ( + process_count or os.cpu_count() # type: ignore[assignment] + ) + num_proc = min(64, process_or_cpu_count) + features = data.features.keys() + tokenized_data = data.map( + map_fn, + num_proc=num_proc, + keep_in_memory=keep_in_memory, + remove_columns=features, + desc="Tokenizing Chats", + ) + super().__init__(tokenized_data.data, *args, **kwargs) diff --git a/src/axolotl/core/datasets/transforms/__init__.py b/src/axolotl/core/datasets/transforms/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/axolotl/core/datasets/transforms/chat_builder.py b/src/axolotl/core/datasets/transforms/chat_builder.py new file mode 100644 index 0000000000..98d5f171a7 --- /dev/null +++ b/src/axolotl/core/datasets/transforms/chat_builder.py @@ -0,0 +1,150 @@ +""" +This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat. +""" +from typing import Any, Mapping, Union + + +def chat_message_transform_builder( # pylint: disable=dangerous-default-value + train_on_inputs=False, + conversations_field: str = "conversations", + message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role" + message_field_content: Union[str, list[str]] = [ + "value", + "text", + "content", + ], # commonly "content" + message_field_training: Union[str, list[str]] = [ + "train", + "weight", + ], # commonly "weight" +): + """Builds a transform that takes a row from the dataset and converts it to a Chat + + Args: + train_on_inputs (bool, optional): + If True, the transform will train on the inputs. If False, the transform will train on the targets. + Defaults to False. + conversations_field (str, optional): + The field name of the conversations. Defaults to "conversations". + message_field_role (str | list[str], optional): + The field name of the role. Defaults to "role". + message_field_content (str | list[str], optional): + The field name of the message content. Defaults to "content". + message_field_training (str | list[str], optional): + The field name of the train/weight. Defaults to "weight". + + Returns: + Callable: + A function that takes a list of conversations and returns a list of messages. + """ + + message_field_role = ( + [message_field_role] + if isinstance(message_field_role, str) + else message_field_role + ) + message_field_content = ( + [message_field_content] + if isinstance(message_field_content, str) + else message_field_content + ) + message_weight_fields = ( + [message_field_training] + if isinstance(message_field_training, str) + else message_field_training + ) + + role_value_mappings = { + "system": "system", + "user": "user", + "human": "user", + "assistant": "assistant", + "gpt": "assistant", + "tool": "tool", + "ipython": "ipython", + } + if train_on_inputs: + role_default_weights_mappings = { + "system": 1, + "user": 1, + "assistant": 1, + "tool": 1, + "ipython": 1, + } + else: + role_default_weights_mappings = { + "system": 0, + "user": 0, + "assistant": 1, + "tool": 0, + "ipython": 0, + } + + def transform_builder(sample: Mapping[str, Any]): + if conversations_field not in sample: + raise ValueError(f"Field '{conversations_field}' not found in sample.") + # if none of the role fields are in the message, raise an error + if not any( + role in sample[conversations_field][0] for role in message_field_role + ): + raise ValueError("No role field found in message.") + role_field = next( + role + for role in message_field_role + if role in sample[conversations_field][0] + ) + if not any( + field in sample[conversations_field][0] for field in message_field_content + ): + raise ValueError("No message_content field found in message.") + message_content_field = next( + field + for field in message_field_content + if field in sample[conversations_field][0] + ) + if not any( + field in sample[conversations_field][0] for field in message_field_training + ): + message_weight_field = None + else: + message_weight_field = next( + field + for field in message_weight_fields + if field in sample[conversations_field][0] + ) + + messages = [] + for message in sample[conversations_field]: + role = role_value_mappings[message[role_field]] + weight = ( + int(message[message_weight_field]) + if message_weight_field + else role_default_weights_mappings[role] + ) + + # TODO if "tool_calls" in message[message_content_field]: then convert tool call to ToolCallContents + if isinstance(message[message_content_field], str): + messages.append( + { + "role": role, + "content": [ + { + "type": "text", + "value": message[message_content_field], + } + ], + "weight": weight, + } + ) + else: + messages.append( + { + "role": role, + "content": message[message_content_field], + "weight": weight, + } + ) + + return {"conversation": messages} + + return transform_builder diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py index 66cd5deeb9..74da20c5e1 100644 --- a/src/axolotl/prompt_strategies/__init__.py +++ b/src/axolotl/prompt_strategies/__init__.py @@ -11,6 +11,10 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None): try: + if strategy == "messages": + from .messages import load as messages_load + + return messages_load(tokenizer, cfg, ds_cfg, processor=processor) load_fn = "load" if strategy.split(".")[-1].startswith("load_"): load_fn = strategy.split(".")[-1] @@ -31,4 +35,5 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None): return None except Exception as exc: # pylint: disable=broad-exception-caught LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}") - return None + raise exc + return None diff --git a/src/axolotl/prompt_strategies/messages/__init__.py b/src/axolotl/prompt_strategies/messages/__init__.py new file mode 100644 index 0000000000..d014d93a6b --- /dev/null +++ b/src/axolotl/prompt_strategies/messages/__init__.py @@ -0,0 +1,34 @@ +"""Module to load message prompt strategies.""" + +import importlib +import inspect +import logging + +LOG = logging.getLogger("axolotl.prompt_strategies.messages") + + +def load(tokenizer, cfg, ds_cfg, processor=None): + try: + strategy = ds_cfg.get("input_transform", "chat") + # pylint: disable=duplicate-code + load_fn = "load" + if strategy.split(".")[-1].startswith("load_"): + load_fn = strategy.split(".")[-1] + strategy = ".".join(strategy.split(".")[:-1]) + mod = importlib.import_module( + f".{strategy}", "axolotl.prompt_strategies.messages" + ) + func = getattr(mod, load_fn) + load_kwargs = {} + sig = inspect.signature(func) + if "ds_cfg" in sig.parameters: + load_kwargs["ds_cfg"] = ds_cfg + if "processor" in sig.parameters: + load_kwargs["processor"] = processor + return func(tokenizer, cfg, **load_kwargs) + except ModuleNotFoundError: + return None + except Exception as exc: # pylint: disable=broad-exception-caught + LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}") + raise exc + return None diff --git a/src/axolotl/prompt_strategies/messages/chat.py b/src/axolotl/prompt_strategies/messages/chat.py new file mode 100644 index 0000000000..35d7649026 --- /dev/null +++ b/src/axolotl/prompt_strategies/messages/chat.py @@ -0,0 +1,84 @@ +""" +Chat dataset wrapping strategy for new internal messages representations +""" +from typing import Any, Callable, Dict, Optional + +from axolotl.core.datasets.chat import TokenizedChatDataset +from axolotl.core.datasets.transforms.chat_builder import chat_message_transform_builder +from axolotl.prompt_tokenizers import DatasetWrappingStrategy + + +class ChatMessageDatasetWrappingStrategy(DatasetWrappingStrategy): + """ + Chat dataset wrapping strategy for new internal messages representations + """ + + def __init__( + self, + processor, + message_transform=None, + formatter=None, + **kwargs, # pylint: disable=unused-argument + ): + """ + :param processor: tokenizer or image processor + :param kwargs: + """ + self.processor = processor + self.dataset = None + self.message_transform = message_transform + self.formatter = formatter + + def wrap_dataset( + self, + dataset, + process_count: Optional[int] = None, + keep_in_memory: Optional[bool] = False, + **kwargs, # pylint: disable=unused-argument + ): + self.dataset = TokenizedChatDataset( + dataset, + message_transform=self.message_transform, + model_transform=self.processor, + formatter=self.formatter, + process_count=process_count, + keep_in_memory=keep_in_memory, + ) + return self.dataset + + +def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): + ds_cfg = ds_cfg or {} + + field_messages = ds_cfg.get("field_messages") + message_field_role = ds_cfg.get("message_field_role") + message_field_content = ds_cfg.get("message_field_content") + message_field_training = ds_cfg.get("message_field_training") + + builder_kwargs = {} + if field_messages: + builder_kwargs["conversations_field"] = field_messages + if message_field_role: + builder_kwargs["message_field_role"] = message_field_role + if message_field_content: + builder_kwargs["message_field_content"] = message_field_content + if message_field_training: + builder_kwargs["message_field_training"] = message_field_training + + chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml")) + format_message = ( + lambda x: x # noqa E731 # pylint: disable=unnecessary-lambda-assignment + ) + if chat_template == "chatml": + from axolotl.core.chat.format.chatml import format_message # noqa F811 + if chat_template.startswith("llama3"): + from axolotl.core.chat.format.llama3x import format_message # noqa F811 + message_transform: Callable = chat_message_transform_builder( + train_on_inputs=ds_cfg.get("train_on_inputs", False), + **builder_kwargs, + ) + strategy = ChatMessageDatasetWrappingStrategy( + tokenizer, message_transform=message_transform, formatter=format_message + ) + + return strategy diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 11dd084a85..51d497a23c 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -30,6 +30,12 @@ class InvalidDataException(Exception): """ +class DatasetWrappingStrategy(abc.ABC): + """ + Abstract class for wrapping datasets for Chat Messages + """ + + class PromptTokenizingStrategy(abc.ABC): """ Abstract class for tokenizing strategies diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 40f4a36abb..3304c62f28 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -102,10 +102,12 @@ class SFTDataset(BaseModel): path: Optional[str] = None split: Optional[str] = None type: Optional[Union[str, UserDefinedPrompterType]] = None + input_transform: Optional[str] = None shards: Optional[int] = None conversation: Optional[str] = None chat_template: Optional[str] = None data_files: Optional[Union[str, List[str]]] = None + input_format: Optional[str] = None name: Optional[str] = None ds_type: Optional[str] = None train_on_split: Optional[str] = None diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 39eb2c4e04..163059c2b8 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -23,6 +23,7 @@ AlpacaMultipleChoicePromptTokenizingStrategy, AlpacaPromptTokenizingStrategy, AlpacaReflectionPTStrategy, + DatasetWrappingStrategy, GPTeacherPromptTokenizingStrategy, JeopardyPromptTokenizingStrategy, OpenAssistantPromptTokenizingStrategy, @@ -573,7 +574,7 @@ def get_dataset_wrapper( d_base_type, dataset, d_prompt_style=None, - processor=None, + processor=None, # pylint: disable=unused-argument ): dataset_wrapper = None dataset_prompter = None @@ -608,15 +609,16 @@ def get_dataset_wrapper( ) elif cfg.skip_prepare_dataset: dataset_wrapper = dataset - elif ds_strategy := load( - config_dataset.type, tokenizer, cfg, config_dataset, processor=processor - ): - dataset_prompter = UnsupportedPrompter() - dataset_wrapper = TokenizedPromptDataset( - ds_strategy, - dataset, - **ds_kwargs, - ) + elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset): + if isinstance(ds_strategy, DatasetWrappingStrategy): + dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs) + else: + dataset_prompter = UnsupportedPrompter() + dataset_wrapper = TokenizedPromptDataset( + ds_strategy, + dataset, + **ds_kwargs, + ) elif d_base_type == "alpaca": dataset_prompter = AlpacaPrompter(d_prompt_style) ds_strategy = AlpacaPromptTokenizingStrategy( diff --git a/tests/core/chat/__init__.py b/tests/core/chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/core/chat/format/__init__.py b/tests/core/chat/format/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/core/chat/test_messages.py b/tests/core/chat/test_messages.py new file mode 100644 index 0000000000..b3be56c590 --- /dev/null +++ b/tests/core/chat/test_messages.py @@ -0,0 +1,197 @@ +""" +Tests for the chat messages module +""" +import unittest + +import pytest +from transformers import AddedToken, AutoTokenizer + +from axolotl.core.chat.format.chatml import format_message +from axolotl.core.chat.messages import ChatFormattedChats, Chats + + +@pytest.fixture(scope="session", name="llama_tokenizer") +def llama_tokenizer_fixture(): + return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B") + + +@pytest.fixture(scope="session", name="chatml_tokenizer") +def llama_tokenizer_w_chatml(llama_tokenizer): + llama_tokenizer.add_special_tokens( + { + "eos_token": AddedToken( + "<|im_end|>", rstrip=False, lstrip=False, normalized=False + ) + } + ) + llama_tokenizer.add_tokens( + [ + AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False), + ] + ) + + return llama_tokenizer + + +@pytest.fixture(scope="session", name="chat_msgs") +def chat_msgs_fixture(): + return { + "conversation": [ + { + "role": "system", + "content": [ + {"type": "text", "value": "You are a helpful assistant."}, + ], + }, + { + "role": "user", + "content": [ + {"type": "text", "value": "What is today's stock price of Apple?"}, + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "tool_call", + "value": { + "name": "get_date", + "arguments": {}, + }, + }, + { + "type": "tool_call", + "value": { + "name": "get_stock_price", + "arguments": {"symbol": "AAPL"}, + }, + }, + ], + "weight": 1, + }, + { + "role": "tool", + "content": [ + { + "type": "tool_response", + "value": { + "name": "get_date", + "content": {"date": "2024-09-09"}, + }, + }, + { + "type": "tool_response", + "value": { + "name": "get_stock_price", + "content": {"symbol": "AAPL", "price": 123.45}, + }, + }, + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "value": "The stock price of Apple is $123.45.\n", + "weight": 0, + }, + { + "type": "text", + "value": "The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.", + }, + { + "type": "text", + "value": "The stock price of Apple on September 9, 2024 is $123.45.", + }, + ], + "weight": 1, + }, + ] + } + + +class TestMessagesCase: + """ + Test cases for the chat messages module + """ + + def test_tool_call_stringify(self, chat_msgs): + chat_msgs_as_obj = Chats(**chat_msgs) + assert '{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}' == str( + chat_msgs_as_obj.conversation[2].content[1].value + ) + + def test_chatml_formatted_wrapper(self, chat_msgs): + chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message) + target_chatml = """<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +What is today's stock price of Apple?<|im_end|> +<|im_start|>assistant + +{"name": "get_date", "arguments": {}} + + +{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}} + +<|im_end|> +<|im_start|>tool + +{"name": "get_date", "content": {"date": "2024-09-09"}} + + +{"name": "get_stock_price", "content": {"symbol": "AAPL", "price": 123.45}} + +<|im_end|> +<|im_start|>assistant +The stock price of Apple is $123.45. +The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.The stock price of Apple on September 9, 2024 is $123.45.<|im_end|>\n""" + assert target_chatml == str(chat_msg_formatted) + + def test_chatml_formatting_tool_call(self, chat_msgs): + chat_msgs_as_obj = Chats(**chat_msgs) + target_chatml_turn2 = """<|im_start|>assistant\n\n{"name": "get_date", "arguments": {}}\n\n\n{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}\n\n<|im_end|>\n""" + assert target_chatml_turn2 == str( + format_message(chat_msgs_as_obj.conversation[2]) + ) + + def test_train_labels(self, chatml_tokenizer, chat_msgs): + chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message) + tokenized = chat_msg_formatted.conversation[2].tokenized(chatml_tokenizer) + # fmt: off + target_labels = [ + -100, -100, -100, # role + 27, 14506, 13735, 397, 5018, 609, 794, + 330, 456, 4257, 498, 330, 16774, 794, 4792, 534, 524, + 14506, 13735, 397, 27, 14506, 13735, 397, 5018, 609, 794, + 330, 456, 31641, 9217, 498, 330, 16774, 794, 5324, 19314, + 794, 330, 84016, 43, 96742, 524, 14506, 13735, 397, + 128256, # <|im_end|> + -100 # trailing newline + ] + # fmt: on + assert tokenized["labels"] == target_labels + + def test_train_labels_2(self, chatml_tokenizer, chat_msgs): + # also test if indivudal contents are set not to train + chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message) + tokenized = chat_msg_formatted.conversation[4].tokenized(chatml_tokenizer) + # fmt: off + target_labels = [ + -100, -100, -100, # role + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # initial response + 27, 78098, 16761, 4113, 3319, 4691, 369, 3432, 596, 5708, 3430, + 315, 8325, 13, 1115, 24897, 814, 1101, 4934, 279, 2457, + 5343, 304, 279, 2077, 4005, 78098, 16761, 5708, 3430, 315, + 8325, 389, 6250, 220, 24, 11, 220, 2366, 19, 374, 400, + 4513, 13, 1774, 13, + 128256, # <|im_end|> + -100, # trailing newline + ] + # fmt: on + assert tokenized["labels"] == target_labels + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/prompt_strategies/messages/__init__.py b/tests/prompt_strategies/messages/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/prompt_strategies/messages/test_chat.py b/tests/prompt_strategies/messages/test_chat.py new file mode 100644 index 0000000000..96c4b6cbbf --- /dev/null +++ b/tests/prompt_strategies/messages/test_chat.py @@ -0,0 +1,62 @@ +""" +tests for chat_template prompt strategy +""" +# pylint: disable=duplicate-code +import logging +import unittest + +from axolotl.prompt_strategies.messages.chat import load +from axolotl.utils.dict import DictDefault + +logging.basicConfig(level=logging.DEBUG) +LOG = logging.getLogger("axolotl") + + +class TestMessagesChatLlama3: + """ + Test class for assistant style datasets with llama-3 prompts using the messages chat llama3 strategy. + """ + + def test_llama3_load(self, llama3_tokenizer, assistant_dataset): + LOG.info("Loading llama-3 tokenizer with assistant dataset") + strategy = load( + llama3_tokenizer, + DictDefault( + { + "train_on_inputs": False, + "sequence_len": 512, + } + ), + DictDefault( + { + "chat_template": "llama3", + "message_field_role": "role", + "message_field_content": "content", + "field_messages": "messages", + } + ), + ) + res = strategy.wrap_dataset(assistant_dataset) + input_ids = res[0]["input_ids"] + # fmt: off + expected_input_ids = [ + 128000, # bos + 128006, 882, 128007, # user header + 271, 15339, 128009, # user prompt eot + 128006, 78191, 128007, # assistant header + 271, 15339, 128009, # assistant response eot + 128006, 882, 128007, + 271, 19045, 29474, 128009, + 128006, 78191, 128007, + 271, 19045, 29474, 128009, + ] + # fmt: on + LOG.debug(f"Expected input_ids: {expected_input_ids}") + LOG.debug(f"Actual input_ids: {input_ids}") + assert ( + input_ids == expected_input_ids + ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + + +if __name__ == "__main__": + unittest.main() From 68b1369de9cc8b77931bc4489899216f40fdb93f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 13 Oct 2024 15:11:13 -0400 Subject: [PATCH 71/89] Reward model (#1879) --- examples/gemma2/reward-model.yaml | 63 +++++++++++++ src/axolotl/core/trainer_builder.py | 65 ++++++++++---- .../prompt_strategies/bradley_terry/README.md | 10 +++ .../bradley_terry/__init__.py | 35 ++++++++ .../bradley_terry/chat_template.py | 88 +++++++++++++++++++ .../prompt_strategies/bradley_terry/llama3.py | 27 ++++++ .../prompt_strategies/chat_template.py | 1 + src/axolotl/train.py | 3 +- .../config/models/input/v0_4_1/__init__.py | 12 +++ src/axolotl/utils/data/sft.py | 18 +++- src/axolotl/utils/trainer.py | 7 +- tests/e2e/test_reward_model_llama.py | 74 ++++++++++++++++ 12 files changed, 382 insertions(+), 21 deletions(-) create mode 100644 examples/gemma2/reward-model.yaml create mode 100644 src/axolotl/prompt_strategies/bradley_terry/README.md create mode 100644 src/axolotl/prompt_strategies/bradley_terry/__init__.py create mode 100644 src/axolotl/prompt_strategies/bradley_terry/chat_template.py create mode 100644 src/axolotl/prompt_strategies/bradley_terry/llama3.py create mode 100644 tests/e2e/test_reward_model_llama.py diff --git a/examples/gemma2/reward-model.yaml b/examples/gemma2/reward-model.yaml new file mode 100644 index 0000000000..c1f993c3ae --- /dev/null +++ b/examples/gemma2/reward-model.yaml @@ -0,0 +1,63 @@ +base_model: google/gemma-2-2b +model_type: AutoModelForSequenceClassification +tokenizer_type: AutoTokenizer + +load_in_8bit: false +load_in_4bit: false +strict: false + +reward_model: true +chat_template: gemma +datasets: + - path: argilla/distilabel-intel-orca-dpo-pairs + type: bradley_terry.chat_template +val_set_size: 0.0 +output_dir: ./outputs/out +remove_unused_columns: false + +sequence_len: 2048 +sample_packing: false +eval_sample_packing: false +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: +eval_table_size: +eval_max_new_tokens: 128 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 9c12b6141a..599144bd34 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -43,8 +43,10 @@ KTOTrainer, ORPOConfig, ORPOTrainer, + RewardConfig, + RewardTrainer, ) -from trl.trainer.utils import pad_to_length +from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler @@ -301,6 +303,13 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig): ) +@dataclass +class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig): + """ + Reward config for Reward training + """ + + class SchedulerMixin(Trainer): """ Mixin class for scheduler setup in CausalTrainer. @@ -398,12 +407,10 @@ class AxolotlTrainer(SchedulerMixin, Trainer): def __init__( self, *_args, - num_epochs=1, bench_data_collator=None, eval_data_collator=None, **kwargs, ): - self.num_epochs = num_epochs self.bench_data_collator = bench_data_collator self.eval_data_collator = eval_data_collator super().__init__(*_args, **kwargs) @@ -1039,6 +1046,14 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): tag_names = ["axolotl", "cpo"] +class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): + """ + Extend the base RewardTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "reward"] + + class TrainerBuilderBase(abc.ABC): """ Base class for trainer builder @@ -1214,6 +1229,8 @@ def _get_trainer_cls(self): return ReLoRATrainer if self.cfg.model_config_type == "mamba": return AxolotlMambaTrainer + if self.cfg.reward_model: + return AxolotlRewardTrainer return AxolotlTrainer def build(self, total_num_steps): @@ -1553,6 +1570,9 @@ def build(self, total_num_steps): trainer_kwargs = {} + if self.cfg.reward_model: + trainer_kwargs["max_length"] = self.cfg.sequence_len + if self.cfg.optimizer in [ "optimi_adamw", "ao_adamw_4bit", @@ -1596,10 +1616,13 @@ def build(self, total_num_steps): "accelerator_config" ] = self.cfg.accelerator_config - training_args = ( - AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg - **training_arguments_kwargs, - ) + training_args_cls = ( + AxolotlTrainingArguments + if not self.cfg.reward_model + else AxolotlRewardConfig + ) + training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg + **training_arguments_kwargs, ) training_args = self.hook_post_create_training_args(training_args) @@ -1621,10 +1644,24 @@ def build(self, total_num_steps): # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html data_collator_kwargs["pad_to_multiple_of"] = 64 + if self.cfg.reward_model: + data_collator_kwargs["max_length"] = self.cfg.sequence_len + trainer_cls = self._get_trainer_cls() trainer_kwargs, trainer_cls = self.hook_pre_create_trainer( trainer_kwargs, trainer_cls ) + if eval_data_collator := self.build_collator( + training_args, is_eval=True, **data_collator_kwargs + ): + if not self.cfg.reward_model: + trainer_kwargs["eval_data_collator"] = eval_data_collator + if not self.cfg.reward_model: + trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq( + self.tokenizer, + return_tensors="pt", + **data_collator_kwargs, + ) trainer = trainer_cls( model=self.model, train_dataset=self.train_dataset, @@ -1632,16 +1669,7 @@ def build(self, total_num_steps): args=training_args, tokenizer=self.tokenizer, data_collator=self.build_collator(training_args, **data_collator_kwargs), - eval_data_collator=self.build_collator( - training_args, is_eval=True, **data_collator_kwargs - ), - bench_data_collator=transformers.DataCollatorForSeq2Seq( - self.tokenizer, - return_tensors="pt", - **data_collator_kwargs, - ), callbacks=self.get_callbacks(), - num_epochs=self.cfg.num_epochs, **trainer_kwargs, ) trainer = self.hook_post_create_trainer(trainer) @@ -1675,9 +1703,12 @@ def build_collator( V2BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, + RewardDataCollatorWithPadding, ] ] - if use_batch_sampler_collator: + if self.cfg.reward_model: + collator = RewardDataCollatorWithPadding + elif use_batch_sampler_collator: if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES: collator = V2BatchSamplerDataCollatorForSeq2Seq elif ( diff --git a/src/axolotl/prompt_strategies/bradley_terry/README.md b/src/axolotl/prompt_strategies/bradley_terry/README.md new file mode 100644 index 0000000000..39cd16137c --- /dev/null +++ b/src/axolotl/prompt_strategies/bradley_terry/README.md @@ -0,0 +1,10 @@ +### example yaml + +```yaml +chat_template: gemma +datasets: + - path: argilla/distilabel-intel-orca-dpo-pairs + type: bradley_terry.chat_template +val_set_size: 0.0 +output_dir: ./outputs/out +``` diff --git a/src/axolotl/prompt_strategies/bradley_terry/__init__.py b/src/axolotl/prompt_strategies/bradley_terry/__init__.py new file mode 100644 index 0000000000..849d84e458 --- /dev/null +++ b/src/axolotl/prompt_strategies/bradley_terry/__init__.py @@ -0,0 +1,35 @@ +"""Module to load prompt strategies.""" + +import importlib +import inspect +import logging + +from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig + +LOG = logging.getLogger("axolotl.prompt_strategies") + + +def load(strategy, tokenizer, cfg, ds_cfg): + # pylint: disable=duplicate-code + try: + load_fn = "load" + if strategy.split(".")[-1].startswith("load_"): + load_fn = strategy.split(".")[-1] + strategy = ".".join(strategy.split(".")[:-1]) + mod = importlib.import_module( + f".{strategy}", "axolotl.prompt_strategies.bradley_terry" + ) + func = getattr(mod, load_fn) + load_kwargs = {} + if strategy == "user_defined": + load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg) + else: + sig = inspect.signature(func) + if "ds_cfg" in sig.parameters: + load_kwargs["ds_cfg"] = ds_cfg + return func(tokenizer, cfg, **load_kwargs) + except ModuleNotFoundError: + return None + except Exception as exc: # pylint: disable=broad-exception-caught + LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}") + return None diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py new file mode 100644 index 0000000000..ccda0a4bde --- /dev/null +++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py @@ -0,0 +1,88 @@ +""" +Bradley-Terry model with chat template prompt strategy. +""" + +from typing import Any, Dict, Optional + +from axolotl.prompt_strategies.chat_template import ( + ChatTemplatePrompter, + ChatTemplateStrategy, +) +from axolotl.utils.chat_templates import chat_templates + + +class BTChatTemplateStrategy(ChatTemplateStrategy): + """ + Bradley-Terry reward model pairwise chat template prompt strategy. + """ + + def tokenize_prompt(self, prompt): + """ + + :param prompt: the actual row of data from the underlying dataset + :return: + """ + + self.messages = "chosen_messages" + # pylint: disable=duplicate-code + prompt[self.messages] = [] + if prompt["system"]: + prompt[self.messages].append({"from": "system", "value": prompt["system"]}) + prompt[self.messages].append({"from": "user", "value": prompt["input"]}) + prompt[self.messages].append({"from": "assistant", "value": prompt["chosen"]}) + chosen_tokenized = super().tokenize_prompt(prompt) + + self.messages = "rejected_messages" + # pylint: disable=duplicate-code + prompt[self.messages] = [] + if prompt["system"]: + prompt[self.messages].append({"from": "system", "value": prompt["system"]}) + prompt[self.messages].append({"from": "user", "value": prompt["input"]}) + prompt[self.messages].append({"from": "assistant", "value": prompt["rejected"]}) + rejected_tokenized = super().tokenize_prompt(prompt) + + return { + "input_ids_chosen": chosen_tokenized["input_ids"], + "attention_mask_chosen": chosen_tokenized["attention_mask"], + "labels_chosen": 1.0, + "input_ids_rejected": rejected_tokenized["input_ids"], + "attention_mask_rejected": rejected_tokenized["attention_mask"], + "labels_rejected": 0.0, + } + + +def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): + ds_cfg = ds_cfg or {} + + prompter_params = { + "tokenizer": tokenizer, + "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")), + "message_field_role": ds_cfg.get("message_field_role", "from"), + "message_field_content": ds_cfg.get("message_field_content", "value"), + "message_field_training": ds_cfg.get("message_field_training", "training"), + "message_field_training_detail": ds_cfg.get( + "message_field_training_detail", "train_detail" + ), + "roles": ds_cfg.get("roles"), + "drop_system_message": ds_cfg.get("drop_system_message", False), + # we need to add one for detecting sequences with exceeding the `sequence_len` limit. + "max_length": cfg.sequence_len + 1 + if not cfg.reward_model + else cfg.sequence_len, + } + + strategy_params = { + "train_on_inputs": cfg.train_on_inputs, + "sequence_len": cfg.sequence_len, + "roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]), + "train_on_eos": ds_cfg.get("train_on_eos", "turn"), + } + + strategy = BTChatTemplateStrategy( + ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params + ) + + if "field_messages" in ds_cfg and hasattr(strategy, "messages"): + strategy.messages = ds_cfg["field_messages"] + + return strategy diff --git a/src/axolotl/prompt_strategies/bradley_terry/llama3.py b/src/axolotl/prompt_strategies/bradley_terry/llama3.py new file mode 100644 index 0000000000..1d586fd5f4 --- /dev/null +++ b/src/axolotl/prompt_strategies/bradley_terry/llama3.py @@ -0,0 +1,27 @@ +""" +chatml transforms for datasets with system, input, chosen, rejected to match llama3 chat template +""" + + +def icr( + cfg, + **kwargs, +): # pylint: disable=possibly-unused-variable,unused-argument + """ + chatml transforms for datasets with system, input, chosen, rejected + ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs + """ + + def transform_fn(sample): + if "system" in sample and sample["system"]: + prompt = ( + f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>" + f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) + else: + prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["chosen"] = prompt + f"{sample['chosen']}<|eot_id|>" + sample["rejected"] = prompt + f"{sample['rejected']}<|eot_id|>" + return sample + + return transform_fn diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 48d52dae11..c7852a707f 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -403,6 +403,7 @@ def get_images(self, prompt): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None): + # pylint: disable=duplicate-code ds_cfg = ds_cfg or {} prompter_params = { diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 855dbc2d3b..6ad3736557 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -102,7 +102,8 @@ def train( model, peft_config = load_model( cfg, tokenizer, processor=processor, inference=cli_args.inference ) - model.generation_config.do_sample = True + if model.generation_config is not None: + model.generation_config.do_sample = True model_ref = None if cfg.rl and cfg.rl != "orpo": diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 3304c62f28..4831da3c8a 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -551,6 +551,7 @@ class Config: resize_token_embeddings_to_32x: Optional[bool] = None rl: Optional[RLType] = None + reward_model: Optional[bool] = None datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore @@ -856,6 +857,17 @@ def hint_sample_packing_padding(cls, data): ) return data + @model_validator(mode="before") + @classmethod + def hint_reward_model_pad(cls, data): + if data.get("reward_model") and not data.get("pad_to_sequence_len"): + LOG.warning( + "`pad_to_sequence_len: true` is recommended when using reward_model" + ) + if data.get("pad_to_sequence_len") is None: + data["pad_to_sequence_len"] = True + return data + @model_validator(mode="before") @classmethod def check_gas_bsz(cls, data): diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 163059c2b8..ce01b44098 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -19,6 +19,7 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.datasets import TokenizedPromptDataset from axolotl.prompt_strategies import load +from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load from axolotl.prompt_tokenizers import ( AlpacaMultipleChoicePromptTokenizingStrategy, AlpacaPromptTokenizingStrategy, @@ -459,7 +460,7 @@ def for_d_in_datasets(dataset_configs): else: LOG.debug("NOT shuffling merged datasets") - if not cfg.skip_prepare_dataset: + if cfg.sample_packing and not cfg.skip_prepare_dataset: dataset, _ = process_datasets_for_packing(cfg, dataset, None) if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: @@ -609,7 +610,20 @@ def get_dataset_wrapper( ) elif cfg.skip_prepare_dataset: dataset_wrapper = dataset - elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset): + elif ds_strategy := config_dataset.type.startswith( + "bradley_terry" + ) and bradley_terry_load( + config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset + ): + dataset_prompter = UnsupportedPrompter() + dataset_wrapper = TokenizedPromptDataset( + ds_strategy, + dataset, + **ds_kwargs, + ) + elif ds_strategy := load( + config_dataset.type, tokenizer, cfg, config_dataset, processor=processor + ): if isinstance(ds_strategy, DatasetWrappingStrategy): dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs) else: diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 30b40925f9..7ebf384aff 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -306,7 +306,11 @@ def process_pretraining_datasets_for_packing( def calculate_total_num_steps(cfg, train_dataset, update=True): - if not cfg.total_num_tokens and not cfg.skip_prepare_dataset: + if ( + not cfg.total_num_tokens + and not cfg.skip_prepare_dataset + and not cfg.reward_model + ): total_num_tokens = np.sum( train_dataset.data.column("input_ids") .to_pandas() @@ -323,6 +327,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): not skip_estimates and not cfg.total_supervised_tokens and not cfg.skip_prepare_dataset + and not cfg.reward_model ): total_supervised_tokens = ( train_dataset.data.column("labels") diff --git a/tests/e2e/test_reward_model_llama.py b/tests/e2e/test_reward_model_llama.py new file mode 100644 index 0000000000..27ac3e25f1 --- /dev/null +++ b/tests/e2e/test_reward_model_llama.py @@ -0,0 +1,74 @@ +""" +E2E tests for reward model lora llama +""" + +import logging +import os +import unittest +from pathlib import Path + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from .utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestRewardModelLoraLlama(unittest.TestCase): + """ + Test case for Llama reward models using LoRA + """ + + @with_temp_dir + def test_rm_fft(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "model_type": "AutoModelForSequenceClassification", + "tokenizer_type": "LlamaTokenizer", + "chat_template": "alpaca", + "reward_model": True, + "sequence_len": 1024, + "pad_to_sequence_len": True, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.0, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "datasets": [ + { + "path": "argilla/distilabel-intel-orca-dpo-pairs", + "type": "bradley_terry.chat_template", + }, + ], + "remove_unused_columns": False, + "max_steps": 10, + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "gradient_checkpointing": True, + "warmup_ratio": 0.1, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() From ec4272c3a0afedadf7bb54f9386dfd51d4a4c2cb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 13 Oct 2024 17:34:37 -0400 Subject: [PATCH 72/89] add ds zero3 to multigpu biweekly tests (#1900) * add ds zero3 to multigpu biweekly tests * fix for upstream api change * use updated accelerate and fix deepspeed tests * stringify the Path, and run multigpu tests if the multigpu tests change for a PR * use correct json rather than yaml * revert accelerate for deepspeed --- requirements.txt | 2 +- tests/e2e/multigpu/test_llama.py | 114 +++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 37ee1e42cf..46d0691b6c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ peft==0.13.2 transformers==4.45.2 tokenizers>=0.20.1 bitsandbytes==0.44.1 -accelerate==1.0.0 +accelerate==0.34.2 datasets==3.0.1 deepspeed==0.14.4 pydantic==2.6.3 diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index 61bb8ed327..957a6a9e36 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -19,6 +19,8 @@ LOG = logging.getLogger("axolotl.tests.e2e.multigpu") os.environ["WANDB_DISABLED"] = "true" +AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent + @pytest.fixture(scope="session", autouse=True) def download_model(): @@ -346,3 +348,115 @@ def test_fsdp_qlora_prequant_packed(self, temp_dir): str(Path(temp_dir) / "config.yaml"), ] ) + + @with_temp_dir + def test_ds_zero3_packed(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "TinyLlama/TinyLlama_v1.1", + "tokenizer_type": "LlamaTokenizer", + "sample_packing": True, + "eval_sample_packing": False, + "pad_to_sequence_len": True, + "sequence_len": 2048, + "val_set_size": 0.05, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 100, + "micro_batch_size": 4, + "gradient_accumulation_steps": 4, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "flash_attention": True, + "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"), + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "accelerate", + "launch", + "--num-processes", + "2", + "-m", + "axolotl.cli.train", + str(Path(temp_dir) / "config.yaml"), + ] + ) + + @with_temp_dir + def test_ds_zero3_qlora_packed(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "TinyLlama/TinyLlama_v1.1", + "tokenizer_type": "LlamaTokenizer", + "load_in_4bit": True, + "adapter": "qlora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "sample_packing": True, + "eval_sample_packing": False, + "pad_to_sequence_len": True, + "sequence_len": 2048, + "val_set_size": 0.05, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 100, + "micro_batch_size": 4, + "gradient_accumulation_steps": 4, + "output_dir": temp_dir, + "learning_rate": 0.0001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "flash_attention": True, + "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"), + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "accelerate", + "launch", + "--num-processes", + "2", + "-m", + "axolotl.cli.train", + str(Path(temp_dir) / "config.yaml"), + ] + ) From 335027f155b34b569d2d7c106c8797569a8eaa56 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 13 Oct 2024 20:04:30 -0400 Subject: [PATCH 73/89] upgrade accelerate to 1.0.1 (#1969) --- requirements.txt | 2 +- src/axolotl/train.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index 46d0691b6c..8f9f55262e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ peft==0.13.2 transformers==4.45.2 tokenizers>=0.20.1 bitsandbytes==0.44.1 -accelerate==0.34.2 +accelerate==1.0.1 datasets==3.0.1 deepspeed==0.14.4 pydantic==2.6.3 diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 6ad3736557..4ce28d8a31 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -10,7 +10,6 @@ import torch import transformers.modelcard -from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import save_fsdp_model from datasets import Dataset @@ -97,8 +96,6 @@ def train( if cfg.adapter: msg += " and peft_config..." LOG.debug(msg) - # we wait unitl the last possible moment to setup Accelerator - Accelerator() model, peft_config = load_model( cfg, tokenizer, processor=processor, inference=cli_args.inference ) From 6d9a3c4d817cd57e702b270c04d2b2d2400c3ad4 Mon Sep 17 00:00:00 2001 From: JohanWork <39947546+JohanWork@users.noreply.github.com> Date: Mon, 14 Oct 2024 22:00:48 +0200 Subject: [PATCH 74/89] examples: Fix config llama3 (#1833) [skip ci] * update llama3 config * llama3 config --- examples/llama-3/instruct-dpo-lora-8b.yml | 1 - examples/llama-3/instruct-lora-8b.yml | 1 - 2 files changed, 2 deletions(-) diff --git a/examples/llama-3/instruct-dpo-lora-8b.yml b/examples/llama-3/instruct-dpo-lora-8b.yml index 14febb810a..dc88350358 100644 --- a/examples/llama-3/instruct-dpo-lora-8b.yml +++ b/examples/llama-3/instruct-dpo-lora-8b.yml @@ -11,7 +11,6 @@ rl: dpo datasets: - path: fozziethebeat/alpaca_messages_2k_dpo_test type: chat_template.default - chat_template: llama3 field_messages: conversation field_chosen: chosen field_rejected: rejected diff --git a/examples/llama-3/instruct-lora-8b.yml b/examples/llama-3/instruct-lora-8b.yml index 4acad59999..ae9a8088c3 100644 --- a/examples/llama-3/instruct-lora-8b.yml +++ b/examples/llama-3/instruct-lora-8b.yml @@ -10,7 +10,6 @@ chat_template: llama3 datasets: - path: fozziethebeat/alpaca_messages_2k_test type: chat_template - chat_template: llama3 field_messages: messages message_field_role: role message_field_content: content From 54673fd6ca39bf86addb20ba7d44a77071b4dc4f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 17 Oct 2024 14:12:31 -0400 Subject: [PATCH 75/89] also debug if other debug args are set (#1977) --- src/axolotl/cli/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index fd5ab3e56c..84836bb793 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -462,7 +462,12 @@ def load_datasets( processor=processor, ) - if cli_args.debug or cfg.debug: + if ( + cli_args.debug + or cfg.debug + or cli_args.debug_text_only + or cli_args.debug_num_examples + ): LOG.info("check_dataset_labels...") check_dataset_labels( train_dataset.select( From f62e23737bc54dbef758fa8c58296ee0b1023e7b Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Thu, 17 Oct 2024 15:15:29 -0400 Subject: [PATCH 76/89] memoize dataset length for eval sample packing (#1974) * wip on multimodal sample packing support * wip on multimodal packing support * llama-1b-yml * setup logging for test * yml * yml * yml * fix for __len__ for eval sample packing * reverted irrelavant changes * reformatted, reverted log message * reverted unnecessary changes * added e2e multigpu testing for eval sample packing * formatting * fixed e2e test_eval params * fix test_eval e2e multigpu * fix test_eval e2e multigpu * Update tests/e2e/multigpu/test_eval.py Co-authored-by: Wing Lian * Update tests/e2e/multigpu/test_eval.py Co-authored-by: Wing Lian --------- Co-authored-by: Wing Lian --- examples/llama-3/qlora-1b.yml | 77 ++++++++++++ src/axolotl/utils/samplers/multipack.py | 13 +- tests/e2e/multigpu/test_eval.py | 155 ++++++++++++++++++++++++ 3 files changed, 239 insertions(+), 6 deletions(-) create mode 100644 examples/llama-3/qlora-1b.yml create mode 100644 tests/e2e/multigpu/test_eval.py diff --git a/examples/llama-3/qlora-1b.yml b/examples/llama-3/qlora-1b.yml new file mode 100644 index 0000000000..fdfe4aa7c8 --- /dev/null +++ b/examples/llama-3/qlora-1b.yml @@ -0,0 +1,77 @@ +base_model: meta-llama/Llama-3.2-1B + +load_in_8bit: false +load_in_4bit: true +strict: false + +datasets: + - path: teknium/GPT4-LLM-Cleaned + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/qlora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true +eval_sample_packing: true +pad_to_sequence_len: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +loss_watchdog_threshold: 5.0 +loss_watchdog_patience: 3 + +warmup_steps: 10 +evals_per_epoch: 4 +eval_table_size: +eval_max_new_tokens: 128 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + pad_token: "<|end_of_text|>" diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 205c2894d1..db14a6819e 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -133,6 +133,8 @@ def __init__( self.eff_total_used = 0 self.eff_total_slots = 0 + self.len_across_ranks = None + def set_epoch(self, epoch: int): self.epoch = epoch @@ -195,15 +197,14 @@ def calc_min_len(estimates: list[(int, float)]): LOG.info(f"gather_len_batches: {repr(estimates)}") return math.floor(0.998 * min(estimates)) - min_len_batches = reduce_and_broadcast( - lambda: num, - calc_min_len, - ) + min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len) return min_len_batches def __len__(self): - len_batches = self.num_batches() - return self.gather_len_batches(len_batches) + if not self.len_across_ranks: + len_batches = self.num_batches() + self.len_across_ranks = self.gather_len_batches(len_batches) + return self.len_across_ranks def _len_est(self): efficiency = ( diff --git a/tests/e2e/multigpu/test_eval.py b/tests/e2e/multigpu/test_eval.py new file mode 100644 index 0000000000..65d26bb824 --- /dev/null +++ b/tests/e2e/multigpu/test_eval.py @@ -0,0 +1,155 @@ +""" +E2E tests for multigpu eval +""" +import logging +import os +import unittest +from pathlib import Path + +import yaml +from accelerate.test_utils import execute_subprocess_async + +from axolotl.utils.dict import DictDefault + +from ..utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e.multigpu") +os.environ["WANDB_DISABLED"] = "true" + +AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent + + +class TestMultiGPUEval(unittest.TestCase): + """ + Test case for MultiGPU Eval Sample Packing + """ + + @with_temp_dir + def test_eval_sample_packing(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "load_in_8bit": False, + "load_in_4bit": True, + "strict": False, + "sequence_len": 2048, + "adapter": "qlora", + "sample_packing": True, + "eval_sample_packing": True, + "pad_to_sequence_len": True, + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "lora_modules_to_save": ["embed_tokens", "lm_head"], + "val_set_size": 0.1, + "special_tokens": {"pad_token": "<|end_of_text|>"}, + "datasets": [ + { + "path": "teknium/GPT4-LLM-Cleaned", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 5, + "micro_batch_size": 2, + "gradient_accumulation_steps": 4, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_8bit", + "lr_scheduler": "cosine", + "flash_attention": True, + "loss_watchdog_threshold": 5.0, + "loss_watchdog_patience": 3, + "bf16": "auto", + "warmup_steps": 1, + "evals_per_epoch": 2, + "eval_max_new_tokens": 128, + "saves_per_epoch": 1, + "logging_steps": 1, + "weight_decay": 0.0, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "accelerate", + "launch", + "--num-processes", + "2", + "-m", + "axolotl.cli.train", + str(Path(temp_dir) / "config.yaml"), + ] + ) + + @with_temp_dir + def test_eval(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "load_in_8bit": False, + "load_in_4bit": True, + "strict": False, + "sequence_len": 2048, + "adapter": "qlora", + "sample_packing": True, + "eval_sample_packing": False, + "pad_to_sequence_len": True, + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "lora_modules_to_save": ["embed_tokens", "lm_head"], + "val_set_size": 0.1, + "special_tokens": {"pad_token": "<|end_of_text|>"}, + "datasets": [ + { + "path": "teknium/GPT4-LLM-Cleaned", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 5, + "micro_batch_size": 2, + "gradient_accumulation_steps": 4, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_8bit", + "lr_scheduler": "cosine", + "flash_attention": True, + "loss_watchdog_threshold": 5.0, + "loss_watchdog_patience": 3, + "bf16": "auto", + "warmup_steps": 1, + "evals_per_epoch": 2, + "eval_max_new_tokens": 128, + "saves_per_epoch": 1, + "logging_steps": 1, + "weight_decay": 0.0, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "accelerate", + "launch", + "--num-processes", + "2", + "-m", + "axolotl.cli.train", + str(Path(temp_dir) / "config.yaml"), + ] + ) From 67f744dc8c9564ef7a42d5df780ae53e319dca61 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 18 Oct 2024 03:36:51 -0400 Subject: [PATCH 77/89] add pytorch 2.5.0 base images (#1979) * add pytorch 2.5.0 base images * make sure num examples for debug is zero and fix comparison --- .github/workflows/base.yml | 6 ++++++ src/axolotl/cli/__init__.py | 2 +- src/axolotl/common/cli.py | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 1b24f2c970..c94093bc97 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -36,6 +36,12 @@ jobs: python_version: "3.11" pytorch: 2.4.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" + - cuda: "124" + cuda_version: 12.4.1 + cudnn_version: "" + python_version: "3.11" + pytorch: 2.5.0 + torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" steps: - name: Checkout uses: actions/checkout@v3 diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 84836bb793..77bb551f8c 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -466,7 +466,7 @@ def load_datasets( cli_args.debug or cfg.debug or cli_args.debug_text_only - or cli_args.debug_num_examples + or int(cli_args.debug_num_examples) > 0 ): LOG.info("check_dataset_labels...") check_dataset_labels( diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index c96f8f81ff..6a3a22e637 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -23,7 +23,7 @@ class TrainerCliArgs: debug: bool = field(default=False) debug_text_only: bool = field(default=False) - debug_num_examples: int = field(default=5) + debug_num_examples: int = field(default=0) inference: bool = field(default=False) merge_lora: bool = field(default=False) prompter: Optional[str] = field(default=None) From e12a2130e990313bb0bce66be8fbbe5b856094dd Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 21 Oct 2024 11:00:45 -0400 Subject: [PATCH 78/89] first pass at pytorch 2.5.0 support (#1982) * first pass at pytorch 2.5.0 support * attempt to install causal_conv1d with mamba * gracefully handle missing xformers * fix import * fix incorrect version, add 2.5.0 * increase tests timeout --- .github/workflows/main.yml | 10 +++ .github/workflows/multi-gpu-e2e.yml | 13 +++- .github/workflows/nightlies.yml | 10 +++ .github/workflows/tests-nightly.yml | 9 ++- .github/workflows/tests.yml | 10 ++- cicd/Dockerfile.jinja | 1 - cicd/multigpu.py | 2 +- cicd/tests.py | 2 +- docker/Dockerfile | 1 - setup.py | 5 +- .../monkeypatch/llama_attn_hijack_flash.py | 61 ++++++------------- src/axolotl/monkeypatch/xformers_/__init__.py | 51 ++++++++++++++++ 12 files changed, 120 insertions(+), 55 deletions(-) create mode 100644 src/axolotl/monkeypatch/xformers_/__init__.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index c27dbedefa..47a4c7f114 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -29,6 +29,11 @@ jobs: python_version: "3.11" pytorch: 2.4.1 axolotl_extras: + - cuda: 124 + cuda_version: 12.4.1 + python_version: "3.11" + pytorch: 2.5.0 + axolotl_extras: runs-on: axolotl-gpu-runner steps: - name: Checkout @@ -86,6 +91,11 @@ jobs: python_version: "3.11" pytorch: 2.4.1 axolotl_extras: + - cuda: 124 + cuda_version: 12.4.1 + python_version: "3.11" + pytorch: 2.5.0 + axolotl_extras: runs-on: axolotl-gpu-runner steps: - name: Checkout diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml index ab886c67f1..d9f0ce7e6c 100644 --- a/.github/workflows/multi-gpu-e2e.yml +++ b/.github/workflows/multi-gpu-e2e.yml @@ -21,10 +21,17 @@ jobs: pytorch: 2.3.1 axolotl_extras: num_gpus: 2 - - cuda: 121 - cuda_version: 12.1.1 + - cuda: 124 + cuda_version: 12.4.1 python_version: "3.11" - pytorch: 2.3.1 + pytorch: 2.4.1 + axolotl_extras: + num_gpus: 2 + nightly_build: "true" + - cuda: 124 + cuda_version: 12.4.1 + python_version: "3.11" + pytorch: 2.5.0 axolotl_extras: num_gpus: 2 nightly_build: "true" diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml index 17c76c24e7..55123a9026 100644 --- a/.github/workflows/nightlies.yml +++ b/.github/workflows/nightlies.yml @@ -28,6 +28,11 @@ jobs: python_version: "3.11" pytorch: 2.4.1 axolotl_extras: + - cuda: 124 + cuda_version: 12.4.1 + python_version: "3.11" + pytorch: 2.5.0 + axolotl_extras: runs-on: axolotl-gpu-runner steps: - name: Checkout @@ -85,6 +90,11 @@ jobs: python_version: "3.11" pytorch: 2.4.1 axolotl_extras: + - cuda: 124 + cuda_version: 12.4.1 + python_version: "3.11" + pytorch: 2.5.0 + axolotl_extras: runs-on: axolotl-gpu-runner steps: - name: Checkout diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml index 8c9e1f49e7..56eaae2398 100644 --- a/.github/workflows/tests-nightly.yml +++ b/.github/workflows/tests-nightly.yml @@ -25,7 +25,7 @@ jobs: fail-fast: false matrix: python_version: ["3.10", "3.11"] - pytorch_version: ["2.3.1", "2.4.1"] + pytorch_version: ["2.3.1", "2.4.1", "2.5.0"] timeout-minutes: 20 steps: @@ -95,6 +95,13 @@ jobs: num_gpus: 1 axolotl_extras: nightly_build: "true" + - cuda: 124 + cuda_version: 12.4.1 + python_version: "3.11" + pytorch: 2.5.0 + num_gpus: 1 + axolotl_extras: + nightly_build: "true" steps: - name: Checkout uses: actions/checkout@v4 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a798bdd5cd..e679f41010 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,7 +36,7 @@ jobs: fail-fast: false matrix: python_version: ["3.10", "3.11"] - pytorch_version: ["2.3.1", "2.4.1"] + pytorch_version: ["2.3.1", "2.4.1", "2.5.0"] timeout-minutes: 20 steps: @@ -72,7 +72,7 @@ jobs: if: github.repository_owner == 'axolotl-ai-cloud' # this job needs to be run on self-hosted GPU runners... runs-on: [self-hosted, modal] - timeout-minutes: 60 + timeout-minutes: 90 needs: [pre-commit, pytest] strategy: @@ -97,6 +97,12 @@ jobs: pytorch: 2.4.1 num_gpus: 1 axolotl_extras: + - cuda: 124 + cuda_version: 12.4.1 + python_version: "3.11" + pytorch: 2.5.0 + num_gpus: 1 + axolotl_extras: steps: - name: Checkout uses: actions/checkout@v4 diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja index 11ce8d8baa..3b082a15b0 100644 --- a/cicd/Dockerfile.jinja +++ b/cicd/Dockerfile.jinja @@ -23,7 +23,6 @@ RUN git fetch origin +$GITHUB_REF && \ git checkout FETCH_HEAD # If AXOLOTL_EXTRAS is set, append it in brackets -RUN pip install causal_conv1d RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \ sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \ sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \ diff --git a/cicd/multigpu.py b/cicd/multigpu.py index be10fbc73a..da726b4731 100644 --- a/cicd/multigpu.py +++ b/cicd/multigpu.py @@ -64,7 +64,7 @@ def run_cmd(cmd: str, run_folder: str): @stub.function( image=cicd_image, gpu=GPU_CONFIG, - timeout=45 * 60, + timeout=60 * 60, cpu=8.0, memory=131072 * N_GPUS, ) diff --git a/cicd/tests.py b/cicd/tests.py index 9c2d830cb7..9ebce9815f 100644 --- a/cicd/tests.py +++ b/cicd/tests.py @@ -65,7 +65,7 @@ def run_cmd(cmd: str, run_folder: str): @stub.function( image=cicd_image, gpu=GPU_CONFIG, - timeout=45 * 60, + timeout=60 * 60, cpu=8.0, memory=131072, ) diff --git a/docker/Dockerfile b/docker/Dockerfile index 2b106f1ed8..4872b3907c 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -20,7 +20,6 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git WORKDIR /workspace/axolotl # If AXOLOTL_EXTRAS is set, append it in brackets -RUN pip install causal_conv1d RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ else \ diff --git a/setup.py b/setup.py index 7d9568dbff..1153d69681 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,9 @@ def parse_requirements(): else: raise ValueError("Invalid version format") - if (major, minor) >= (2, 4): + if (major, minor) >= (2, 5): + _install_requires.pop(_install_requires.index(xformers_version)) + elif (major, minor) >= (2, 4): if patch == 0: _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.27") @@ -102,6 +104,7 @@ def parse_requirements(): ], "mamba-ssm": [ "mamba-ssm==1.2.0.post1", + "causal_conv1d", ], "auto-gptq": [ "auto-gptq==0.5.1", diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 4c3571ea4f..c804d0c6b9 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -22,7 +22,6 @@ apply_rotary_pos_emb, repeat_kv, ) -from xformers.ops import SwiGLU from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name @@ -44,7 +43,19 @@ LOG = logging.getLogger("axolotl") +def is_xformers_available() -> bool: + try: + import xformers # pylint: disable=unused-import # noqa: F401 + + return True + except ImportError: + return False + + def is_xformers_swiglu_available() -> bool: + if not is_xformers_available(): + return False + from xformers.ops.common import get_xformers_operator try: @@ -57,6 +68,11 @@ def is_xformers_swiglu_available() -> bool: def replace_llama_mlp_with_swiglu(model): + if is_xformers_swiglu_available(): + from axolotl.monkeypatch.xformers_ import FusedMLP + else: + raise RuntimeError("xformers SwiGLU not available for this environment") + for name, module in model.named_modules(): if isinstance(module, LlamaMLP): mlp = FusedMLP( @@ -181,49 +197,6 @@ def _post_training(self, model, name): set_module_name(model, name, new_attn) -class FusedMLP(torch.nn.Module): - """ - Fused MLP layer for incrementally improved training efficiency - """ - - def __init__( - self, - config, - gate_proj: torch.nn.Linear, - up_proj: torch.nn.Linear, - down_proj: torch.nn.Linear, - ): - super().__init__() - self.config = config - self.swiglu = SwiGLU( - in_features=config.hidden_size, - hidden_features=config.intermediate_size, - bias=False, - _pack_weights=True, - ) - # overwrite initialized weights with pretrained weights - self.swiglu.w12.weight.data = torch.cat( - (gate_proj.weight.data, up_proj.weight.data), dim=0 - ) - self.swiglu.w3.weight.data = down_proj.weight.data - - def _post_training(self, model, name): - w1, w2 = torch.split( # pylint: disable=invalid-name - self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0 - ) - - # Assign the split weights back to the original layers - new_mlp = LlamaMLP(self.config) - new_mlp.gate_proj.weight.data = w1 - new_mlp.up_proj.weight.data = w2 - new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data - - set_module_name(model, name, new_mlp) - - def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name - return self.swiglu(x) - - # Disable the transformation of the attention mask in LlamaModel as the flash attention # requires the attention mask to be the same as the key_padding_mask def _prepare_decoder_attention_mask( diff --git a/src/axolotl/monkeypatch/xformers_/__init__.py b/src/axolotl/monkeypatch/xformers_/__init__.py new file mode 100644 index 0000000000..bddc036b24 --- /dev/null +++ b/src/axolotl/monkeypatch/xformers_/__init__.py @@ -0,0 +1,51 @@ +""" +Fused MLP layer for incrementally improved training efficiency +""" +import torch +from transformers.models.llama.modeling_llama import LlamaMLP +from xformers.ops import SwiGLU + +from axolotl.monkeypatch.utils import set_module_name + + +class FusedMLP(torch.nn.Module): + """ + Fused MLP layer for incrementally improved training efficiency + """ + + def __init__( + self, + config, + gate_proj: torch.nn.Linear, + up_proj: torch.nn.Linear, + down_proj: torch.nn.Linear, + ): + super().__init__() + self.config = config + self.swiglu = SwiGLU( + in_features=config.hidden_size, + hidden_features=config.intermediate_size, + bias=False, + _pack_weights=True, + ) + # overwrite initialized weights with pretrained weights + self.swiglu.w12.weight.data = torch.cat( + (gate_proj.weight.data, up_proj.weight.data), dim=0 + ) + self.swiglu.w3.weight.data = down_proj.weight.data + + def _post_training(self, model, name): + w1, w2 = torch.split( # pylint: disable=invalid-name + self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0 + ) + + # Assign the split weights back to the original layers + new_mlp = LlamaMLP(self.config) + new_mlp.gate_proj.weight.data = w1 + new_mlp.up_proj.weight.data = w2 + new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data + + set_module_name(model, name, new_mlp) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name + return self.swiglu(x) From 955cca41fc0c8c174be4ff46c4f66937d227848e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 21 Oct 2024 19:50:50 -0400 Subject: [PATCH 79/89] don't explicitly set cpu pytorch version (#1986) use a constraint file use min version of xformers don't install autoawq with pytorch 2.5.0 debugging for errors upgrade pip first fix action yml add back try/except retry w/o constraint use --no-build-isolation show torch version install setuptools and wheel add back try/except --- .github/workflows/tests.yml | 10 +++++++--- requirements.txt | 2 +- setup.py | 7 ++++++- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e679f41010..130ac6e7b6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -49,14 +49,18 @@ jobs: python-version: ${{ matrix.python_version }} cache: 'pip' # caching pip dependencies + - name: upgrade pip + run: | + pip3 install --upgrade pip + pip3 install --upgrade packaging setuptools wheel + - name: Install PyTorch run: | - pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu + pip3 install torch==${{ matrix.pytorch_version }} - name: Install dependencies run: | - pip3 install --upgrade pip - pip3 install --upgrade packaging + pip3 show torch pip3 install -U -e . pip3 install -r requirements-tests.txt diff --git a/requirements.txt b/requirements.txt index 8f9f55262e..067be05cf2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ flash-attn==2.6.3 sentencepiece wandb einops -xformers==0.0.28.post1 +xformers>=0.0.23.post1 optimum==1.16.2 hf_transfer colorama diff --git a/setup.py b/setup.py index 1153d69681..17347f0632 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,8 @@ def parse_requirements(): try: xformers_version = [req for req in _install_requires if "xformers" in req][0] torchao_version = [req for req in _install_requires if "torchao" in req][0] + autoawq_version = [req for req in _install_requires if "autoawq" in req][0] + if "Darwin" in platform.system(): # don't install xformers on MacOS _install_requires.pop(_install_requires.index(xformers_version)) @@ -52,10 +54,14 @@ def parse_requirements(): if (major, minor) >= (2, 5): _install_requires.pop(_install_requires.index(xformers_version)) + _install_requires.pop(_install_requires.index(autoawq_version)) elif (major, minor) >= (2, 4): if patch == 0: _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.27") + else: + _install_requires.pop(_install_requires.index(xformers_version)) + _install_requires.append("xformers==0.0.28.post1") elif (major, minor) >= (2, 3): _install_requires.pop(_install_requires.index(torchao_version)) if patch == 0: @@ -75,7 +81,6 @@ def parse_requirements(): except PackageNotFoundError: pass - return _install_requires, _dependency_links From 5c629ee4447b64b77f465bc14ded44d37efdfa9b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 21 Oct 2024 19:51:06 -0400 Subject: [PATCH 80/89] use torch 2.4.1 images as latest now that torch 2.5.0 is out (#1987) --- .github/workflows/main.yml | 4 ++-- .github/workflows/nightlies.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 47a4c7f114..3b82f6a510 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -23,12 +23,12 @@ jobs: python_version: "3.11" pytorch: 2.3.1 axolotl_extras: mamba-ssm - is_latest: true - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" pytorch: 2.4.1 axolotl_extras: + is_latest: true - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" @@ -85,12 +85,12 @@ jobs: python_version: "3.11" pytorch: 2.3.1 axolotl_extras: - is_latest: true - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" pytorch: 2.4.1 axolotl_extras: + is_latest: true - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml index 55123a9026..b2110e737d 100644 --- a/.github/workflows/nightlies.yml +++ b/.github/workflows/nightlies.yml @@ -22,12 +22,12 @@ jobs: python_version: "3.11" pytorch: 2.3.1 axolotl_extras: - is_latest: true - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" pytorch: 2.4.1 axolotl_extras: + is_latest: true - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" @@ -84,12 +84,12 @@ jobs: python_version: "3.11" pytorch: 2.3.1 axolotl_extras: - is_latest: true - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" pytorch: 2.4.1 axolotl_extras: + is_latest: true - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" From 9bd5f7d015bb30acd84f7b4bc780123184fcf3b8 Mon Sep 17 00:00:00 2001 From: Adam Hazell <34248583+awhazell@users.noreply.github.com> Date: Tue, 22 Oct 2024 13:52:21 +0100 Subject: [PATCH 81/89] Log checkpoints as mlflow artifacts (#1976) * Ensure hf_mlflow_log_artifact config var is set in env * Add transformer MLflowCallback to callbacks list when mlflow enabled * Test hf_mlflow_log_artifacts is set correctly * Test mlflow not being used by default --- src/axolotl/core/trainer_builder.py | 9 +++-- src/axolotl/utils/mlflow_.py | 4 +++ tests/test_validation.py | 56 +++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 2 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 599144bd34..f05efe7b82 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1119,12 +1119,17 @@ def get_callbacks(self) -> List[TrainerCallback]: SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) ) if self.cfg.use_mlflow and is_mlflow_available(): + from transformers.integrations.integration_utils import MLflowCallback + from axolotl.utils.callbacks.mlflow_ import ( SaveAxolotlConfigtoMlflowCallback, ) - callbacks.append( - SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path) + callbacks.extend( + [ + SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path), + MLflowCallback, + ] ) if self.cfg.use_comet and is_comet_available(): from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback diff --git a/src/axolotl/utils/mlflow_.py b/src/axolotl/utils/mlflow_.py index ce77390342..8710b07d06 100644 --- a/src/axolotl/utils/mlflow_.py +++ b/src/axolotl/utils/mlflow_.py @@ -16,3 +16,7 @@ def setup_mlflow_env_vars(cfg: DictDefault): # Enable mlflow if experiment name is present if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0: cfg.use_mlflow = True + + # Enable logging hf artifacts in mlflow if value is truthy + if cfg.hf_mlflow_log_artifacts is True: + os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "true" diff --git a/tests/test_validation.py b/tests/test_validation.py index 6e0d0ad2a5..fb63977f5c 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -13,6 +13,7 @@ from axolotl.utils.config import validate_config from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities from axolotl.utils.dict import DictDefault +from axolotl.utils.mlflow_ import setup_mlflow_env_vars from axolotl.utils.models import check_model_config from axolotl.utils.wandb_ import setup_wandb_env_vars @@ -1432,3 +1433,58 @@ def test_comet_sets_env(self, minimal_cfg): for key in comet_env.keys(): os.environ.pop(key, None) + + +class TestValidationMLflow(BaseValidation): + """ + Validation test for MLflow + """ + + def test_hf_mlflow_artifacts_config_sets_env(self, minimal_cfg): + cfg = ( + DictDefault( + { + "hf_mlflow_log_artifacts": True, + } + ) + | minimal_cfg + ) + + new_cfg = validate_config(cfg) + + assert new_cfg.hf_mlflow_log_artifacts is True + + # Check it's not already present in env + assert "HF_MLFLOW_LOG_ARTIFACTS" not in os.environ + + setup_mlflow_env_vars(new_cfg) + + assert os.environ.get("HF_MLFLOW_LOG_ARTIFACTS") == "true" + + os.environ.pop("HF_MLFLOW_LOG_ARTIFACTS", None) + + def test_mlflow_not_used_by_default(self, minimal_cfg): + cfg = DictDefault({}) | minimal_cfg + + new_cfg = validate_config(cfg) + + setup_mlflow_env_vars(new_cfg) + + assert cfg.use_mlflow is not True + + cfg = ( + DictDefault( + { + "mlflow_experiment_name": "foo", + } + ) + | minimal_cfg + ) + + new_cfg = validate_config(cfg) + + setup_mlflow_env_vars(new_cfg) + + assert new_cfg.use_mlflow is True + + os.environ.pop("MLFLOW_EXPERIMENT_NAME", None) From 718cfb2dd1ff2a03b89e3b95f0b1aa1e04046e6e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 22 Oct 2024 13:54:24 -0400 Subject: [PATCH 82/89] revert image tagged as main-latest (#1990) --- .github/workflows/main.yml | 4 ++-- .github/workflows/nightlies.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 3b82f6a510..47a4c7f114 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -23,12 +23,12 @@ jobs: python_version: "3.11" pytorch: 2.3.1 axolotl_extras: mamba-ssm + is_latest: true - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" pytorch: 2.4.1 axolotl_extras: - is_latest: true - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" @@ -85,12 +85,12 @@ jobs: python_version: "3.11" pytorch: 2.3.1 axolotl_extras: + is_latest: true - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" pytorch: 2.4.1 axolotl_extras: - is_latest: true - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml index b2110e737d..55123a9026 100644 --- a/.github/workflows/nightlies.yml +++ b/.github/workflows/nightlies.yml @@ -22,12 +22,12 @@ jobs: python_version: "3.11" pytorch: 2.3.1 axolotl_extras: + is_latest: true - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" pytorch: 2.4.1 axolotl_extras: - is_latest: true - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" @@ -84,12 +84,12 @@ jobs: python_version: "3.11" pytorch: 2.3.1 axolotl_extras: + is_latest: true - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" pytorch: 2.4.1 axolotl_extras: - is_latest: true - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" From 1d6a5e2bd638778a42d757ff0cb600f918eb1c31 Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Fri, 25 Oct 2024 21:06:56 +0800 Subject: [PATCH 83/89] Refactor func load_model to class ModelLoader (#1909) --- cicd/cicd.sh | 2 +- src/axolotl/utils/models.py | 1136 +++++++++++++++++++--------------- tests/e2e/test_load_model.py | 95 +++ tests/utils/test_models.py | 91 ++- 4 files changed, 826 insertions(+), 498 deletions(-) create mode 100644 tests/e2e/test_load_model.py diff --git a/cicd/cicd.sh b/cicd/cicd.sh index 104a8f84ab..483d62a7ad 100755 --- a/cicd/cicd.sh +++ b/cicd/cicd.sh @@ -1,6 +1,6 @@ #!/bin/bash set -e -pytest --ignore=tests/e2e/ /workspace/axolotl/tests/ +pytest -n4 --ignore=tests/e2e/ /workspace/axolotl/tests/ pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/ pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index c18af9760f..5e53df72cb 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -324,671 +324,823 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): return processor -def load_model( - cfg: DictDefault, - tokenizer: PreTrainedTokenizerBase, - *, - processor: ProcessorMixin = None, # pylint: disable=unused-argument - inference: bool = False, - reference_model: bool = False, - **kwargs, # pylint: disable=unused-argument -) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: +class ModelLoader: """ - Load a model for a given configuration and tokenizer. + ModelLoader: managing all the config and monkey patches while loading model """ - base_model = cfg.base_model - model_type = cfg.type_of_model - model_config = load_model_config(cfg) - - # load any patches from plugins - from axolotl.integrations.base import PluginManager + def __init__( + self, + cfg: DictDefault, + tokenizer: PreTrainedTokenizerBase, + *, + processor: ProcessorMixin = None, # pylint: disable=unused-argument + inference: bool = False, + reference_model: bool = False, + **kwargs, # pylint: disable=unused-argument + ) -> None: + self.cfg = cfg + self.tokenizer = tokenizer + self.inference: bool = inference + self.reference_model: bool = reference_model + + # init model kwargs + self.model_kwargs: Dict[str, Any] = {} + if cfg.model_kwargs: + for key, val in cfg.model_kwargs.items(): + self.model_kwargs[key] = val + + # init model + self.model: PreTrainedModel + self.base_model = cfg.base_model + self.model_type = cfg.type_of_model + + # init model config + self.model_config = load_model_config(cfg) + if cfg.is_multimodal: + self.text_model_config = self.model_config.text_config + else: + self.text_model_config = self.model_config - plugin_manager = PluginManager.get_instance() - plugin_manager.pre_model_load(cfg) + self.AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name - if cfg.is_multimodal: - text_model_config = model_config.text_config - else: - text_model_config = model_config + def apply_patches(self) -> None: + # load any patches from plugins + from axolotl.integrations.base import PluginManager - # TODO refactor as a kwarg - load_in_8bit = cfg.load_in_8bit + plugin_manager = PluginManager.get_instance() + plugin_manager.pre_model_load(self.cfg) - if cfg.gradient_checkpointing == "unsloth": - transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper + if self.cfg.gradient_checkpointing == "unsloth": + transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper - if hasattr(model_config, "model_type") and model_config.model_type == "mllama": - if cfg.flash_attention: - from axolotl.monkeypatch.attention.mllama import patch_mllama + if self.cfg.flash_attention: + self.patch_attention() - patch_mllama() + if self.cfg.sample_packing and self.cfg.s2_attention: + raise ValueError( + "Received `sample_packing=true` and `s2_attention=true`; however, \ + shifted-sparse attention does not currently support sample packing." + ) - if hasattr(model_config, "model_type") and model_config.model_type == "btlm": - if cfg.flash_attention: - from axolotl.monkeypatch.btlm_attn_hijack_flash import ( - replace_btlm_attn_with_flash_attn, + if ( + self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES + and self.cfg.flash_attention + and self.cfg.sample_packing + ): + patch_for_multipack( + self.cfg.model_config_type, + model_name=self.cfg.base_model, + is_remote_code=self.cfg.trust_remote_code, ) - replace_btlm_attn_with_flash_attn(cfg.base_model) + if self.cfg.is_llama_derived_model: + self.patch_loss() + if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: + from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora - if ( - hasattr(model_config, "model_type") - and model_config.model_type == "stablelm_epoch" - ): - if cfg.flash_attention and cfg.sample_packing: - from axolotl.monkeypatch.stablelm_attn_hijack_flash import ( - replace_stablelm_attn_with_flash_attn, + patch_self_attn_lora() + elif self.cfg.is_llama_derived_model: + self.patch_llama_derived_model() + + if ( + self.cfg.model_config_type == "mistral" + and self.cfg.flash_attn_cross_entropy_loss + ): + from axolotl.monkeypatch.mistral_attn_hijack_flash import ( + patch_mistral_cross_entropy, ) - replace_stablelm_attn_with_flash_attn(cfg.base_model) + patch_mistral_cross_entropy() - if cfg.sample_packing and cfg.s2_attention: - raise ValueError( - "Received `sample_packing=true` and `s2_attention=true`; however, \ - shifted-sparse attention does not currently support sample packing." - ) + def patch_attention(self) -> None: + if hasattr(self.model_config, "model_type"): + if self.model_config.model_type == "mllama" and self.cfg.flash_attention: + from axolotl.monkeypatch.attention.mllama import patch_mllama - if ( - cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES - and cfg.flash_attention - and cfg.sample_packing - ): - patch_for_multipack( - cfg.model_config_type, - model_name=cfg.base_model, - is_remote_code=cfg.trust_remote_code, - ) + patch_mllama() - if cfg.is_llama_derived_model: - from axolotl.monkeypatch.llama_attn_hijack_flash import ( - patch_llama_cross_entropy, - patch_llama_rms_norm, - ) + if self.model_config.model_type == "btlm": + from axolotl.monkeypatch.btlm_attn_hijack_flash import ( + replace_btlm_attn_with_flash_attn, + ) - if cfg.flash_attn_cross_entropy: - patch_llama_cross_entropy() - if cfg.flash_attn_rms_norm: - patch_llama_rms_norm() - elif cfg.unsloth_rms_norm: - from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm - - patch_unsloth_layernorm() - if cfg.unsloth_cross_entropy_loss: - from axolotl.monkeypatch.unsloth_ import ( - integrate_cross_entropy_loss_patch, + replace_btlm_attn_with_flash_attn(self.cfg.base_model) + + if ( + self.model_config.model_type == "stablelm_epoch" + and self.cfg.sample_packing + ): + from axolotl.monkeypatch.stablelm_attn_hijack_flash import ( + replace_stablelm_attn_with_flash_attn, ) - integrate_cross_entropy_loss_patch(model_type="llama") - if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: - from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora + replace_stablelm_attn_with_flash_attn(self.cfg.base_model) + + def patch_loss(self) -> None: + """ + Patch loss functions + """ + from axolotl.monkeypatch.llama_attn_hijack_flash import ( + patch_llama_cross_entropy, + patch_llama_rms_norm, + ) + + if self.cfg.flash_attn_cross_entropy: + patch_llama_cross_entropy() + if self.cfg.flash_attn_rms_norm: + patch_llama_rms_norm() + elif self.cfg.unsloth_rms_norm: + from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm + + patch_unsloth_layernorm() + if self.cfg.unsloth_cross_entropy_loss: + from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch - patch_self_attn_lora() - elif cfg.is_llama_derived_model: - # Modify all llama derived models in one block + integrate_cross_entropy_loss_patch(model_type="llama") + if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: + from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora + + patch_self_attn_lora() - if cfg.flash_attention: + def patch_llama_derived_model(self) -> None: + """ + Modify all llama derived models in one block + """ + + if self.cfg.flash_attention: from axolotl.monkeypatch.llama_attn_hijack_flash import ( replace_llama_attn_with_flash_attn, ) - if cfg.sample_packing: - if cfg.device not in ["mps", "cpu"] and not inference: + if self.cfg.sample_packing: + if self.cfg.device not in ["mps", "cpu"] and not self.inference: LOG.info("patching with flash attention for sample packing") replace_llama_attn_with_flash_attn( packed=True, - cross_entropy=cfg.flash_attn_cross_entropy, - rms_norm=cfg.flash_attn_rms_norm, + cross_entropy=self.cfg.flash_attn_cross_entropy, + rms_norm=self.cfg.flash_attn_rms_norm, ) - elif cfg.s2_attention: + elif self.cfg.s2_attention: LOG.info("patching w/ flash-enabled, shifted-sparse attention") replace_llama_attn_with_flash_attn( packed=False, - cross_entropy=cfg.flash_attn_cross_entropy, - rms_norm=cfg.flash_attn_rms_norm, + cross_entropy=self.cfg.flash_attn_cross_entropy, + rms_norm=self.cfg.flash_attn_rms_norm, use_shifted_sparse_attn=True, ) - elif cfg.flash_attn_cross_entropy or cfg.flash_attn_rms_norm: + elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm: replace_llama_attn_with_flash_attn( packed=False, - cross_entropy=cfg.flash_attn_cross_entropy, - rms_norm=cfg.flash_attn_rms_norm, + cross_entropy=self.cfg.flash_attn_cross_entropy, + rms_norm=self.cfg.flash_attn_rms_norm, ) - elif cfg.xformers_attention: + elif self.cfg.xformers_attention: from axolotl.monkeypatch.llama_attn_hijack_xformers import ( hijack_llama_attention, ) LOG.info("patching with xformers attention") hijack_llama_attention() - elif cfg.sample_packing: + elif self.cfg.sample_packing: from axolotl.monkeypatch.llama_patch_multipack import ( hijack_llama_prepare_4d_mask, ) LOG.info("patching llama _prepare_4d_causal_attention_mask*") hijack_llama_prepare_4d_mask() - elif cfg.s2_attention: + elif self.cfg.s2_attention: raise NotImplementedError( "Shifted-sparse attention not currently implemented without flash attention." ) - if cfg.unsloth_cross_entropy_loss: + if self.cfg.unsloth_cross_entropy_loss: from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch integrate_cross_entropy_loss_patch(model_type="llama") - if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: + if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora patch_self_attn_lora() - # Modify mistral derived models - if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss: - from axolotl.monkeypatch.mistral_attn_hijack_flash import ( - patch_mistral_cross_entropy, - ) - - patch_mistral_cross_entropy() - - model_kwargs: Dict[str, Any] = {} - - if cfg.model_kwargs: - for key, val in cfg.model_kwargs.items(): - model_kwargs[key] = val + def set_auto_model_loader(self) -> None: + """set self.AutoModelLoader + - default value: AutoModelForCausalLM (set at __init__) + - when using a multi modality model, self.AutoModelLoader should + be set according to model type of the model + """ + if self.cfg.is_multimodal: + if self.model_config.model_type == "llava": + self.AutoModelLoader = ( # pylint: disable=invalid-name + LlavaForConditionalGeneration + ) + elif self.model_config.model_type == "mllama": + self.AutoModelLoader = ( # pylint: disable=invalid-name + MllamaForConditionalGeneration + ) + else: + self.AutoModelLoader = ( + AutoModelForVision2Seq # pylint: disable=invalid-name + ) - max_memory = cfg.max_memory - device_map = cfg.device_map + def set_device_map_config(self) -> None: + device_map = self.cfg.device_map + max_memory = self.cfg.max_memory - AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name - if cfg.is_multimodal: - if model_config.model_type == "llava": - AutoModelLoader = ( # pylint: disable=invalid-name - LlavaForConditionalGeneration - ) - elif model_config.model_type == "mllama": - AutoModelLoader = ( # pylint: disable=invalid-name - MllamaForConditionalGeneration + if self.cfg.gpu_memory_limit: + gpu_memory_limit = ( + str(self.cfg.gpu_memory_limit) + "GiB" + if isinstance(self.cfg.gpu_memory_limit, int) + else self.cfg.gpu_memory_limit ) - else: - AutoModelLoader = AutoModelForVision2Seq # pylint: disable=invalid-name - - if cfg.gpu_memory_limit: - gpu_memory_limit = ( - str(cfg.gpu_memory_limit) + "GiB" - if isinstance(cfg.gpu_memory_limit, int) - else cfg.gpu_memory_limit - ) - max_memory = {} - for i in range(torch.cuda.device_count()): - max_memory[i] = gpu_memory_limit - max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything + max_memory = {} + for i in range(torch.cuda.device_count()): + max_memory[i] = gpu_memory_limit + max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything - if max_memory is not None: - # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py - from accelerate import infer_auto_device_map + if max_memory is not None: + # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py + from accelerate import infer_auto_device_map - with init_empty_weights(): - model_canvas = AutoModelLoader.from_config( - model_config, trust_remote_code=cfg.trust_remote_code or False + with init_empty_weights(): + model_canvas = self.AutoModelLoader.from_config( + self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + ) + model_canvas.tie_weights() + device_map = infer_auto_device_map( + model_canvas, + max_memory=max_memory, + dtype=self.cfg.torch_dtype, ) - model_canvas.tie_weights() - device_map = infer_auto_device_map( - model_canvas, - max_memory=max_memory, - dtype=cfg.torch_dtype, - ) - # We can discard max_memory now as we have a device map set up for us - max_memory = None - - model_kwargs["device_map"] = device_map - model_kwargs["torch_dtype"] = cfg.torch_dtype - - if torch.backends.mps.is_available(): - model_kwargs["device_map"] = "mps:0" - - # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss - # if cfg.rl: - # if torch.cuda.device_count() > 1: - # if reference_model: - # model_kwargs["device_map"] = "cuda:" + str( - # torch.cuda.current_device() + 1 - # ) - # else: - # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device()) - - if is_deepspeed_zero3_enabled(): - del model_kwargs["device_map"] + # We can discard max_memory now as we have a device map set up for us + max_memory = None + + self.model_kwargs["device_map"] = device_map + self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype + + if torch.backends.mps.is_available(): + self.model_kwargs["device_map"] = "mps:0" + + # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss + # if cfg.rl: + # if torch.cuda.device_count() > 1: + # if reference_model: + # model_kwargs["device_map"] = "cuda:" + str( + # torch.cuda.current_device() + 1 + # ) + # else: + # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device()) + + if is_deepspeed_zero3_enabled(): + del self.model_kwargs["device_map"] + + def set_quantization_config(self) -> None: + self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit + self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit + + if self.cfg.gptq: + if not hasattr(self.model_config, "quantization_config"): + LOG.warning( + "model config does not contain quantization_config information" + ) + else: + if self.cfg.gptq_disable_exllama is not None: + self.model_config.quantization_config[ + "disable_exllama" + ] = self.cfg.gptq_disable_exllama + self.model_kwargs["quantization_config"] = GPTQConfig( + **self.model_config.quantization_config + ) + if ( + self.cfg.adapter in ["qlora", "lora"] + and hasattr(self.model_config, "quantization_config") + and self.model_config.quantization_config["quant_method"] + in ["gptq", "awq", "bitsandbytes"] + ): + if self.model_config.quantization_config["quant_method"] == "gptq": + self.model_kwargs["quantization_config"] = GPTQConfig( + **self.model_config.quantization_config + ) + elif self.model_config.quantization_config["quant_method"] == "awq": + self.model_kwargs["quantization_config"] = AwqConfig( + **self.model_config.quantization_config + ) + elif ( + self.model_config.quantization_config["quant_method"] == "bitsandbytes" + ): + self.model_kwargs["quantization_config"] = BitsAndBytesConfig( + **self.model_config.quantization_config + ) + elif self.cfg.adapter == "qlora" and ( + "load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"] + ): + bnb_config = { + "load_in_4bit": True, + "llm_int8_threshold": 6.0, + "llm_int8_has_fp16_weight": False, + "bnb_4bit_compute_dtype": self.cfg.torch_dtype, + "bnb_4bit_use_double_quant": True, + "bnb_4bit_quant_type": "nf4", + "bnb_4bit_quant_storage": torch.bfloat16, + } + if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not ( + self.cfg.deepspeed or self.cfg.fsdp + ): + # for some reason, this causes the loss to be off by an order of magnitude + # but deepspeed needs this still in bfloat16 + bnb_config["bnb_4bit_quant_storage"] = torch.float32 - if cfg.revision_of_model: - model_kwargs["revision"] = cfg.revision_of_model + if self.cfg.bnb_config_kwargs: + bnb_config.update(self.cfg.bnb_config_kwargs) - if cfg.gptq: - if not hasattr(model_config, "quantization_config"): - LOG.warning("model config does not contain quantization_config information") - else: - if cfg.gptq_disable_exllama is not None: - model_config.quantization_config[ - "disable_exllama" - ] = cfg.gptq_disable_exllama - model_kwargs["quantization_config"] = GPTQConfig( - **model_config.quantization_config + self.model_kwargs["quantization_config"] = BitsAndBytesConfig( + **bnb_config, ) - if ( - cfg.adapter in ["qlora", "lora"] - and hasattr(model_config, "quantization_config") - and model_config.quantization_config["quant_method"] - in ["gptq", "awq", "bitsandbytes"] - ): - if model_config.quantization_config["quant_method"] == "gptq": - model_kwargs["quantization_config"] = GPTQConfig( - **model_config.quantization_config + elif self.cfg.adapter == "lora" and ( + "load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"] + ): + bnb_config = { + "load_in_8bit": True, + } + # Exclude mamba blocks from int8 quantization for jamba + if self.cfg.model_config_type == "jamba": + bnb_config["llm_int8_skip_modules"] = ["mamba"] + self.model_kwargs["quantization_config"] = BitsAndBytesConfig( + **bnb_config, ) - elif model_config.quantization_config["quant_method"] == "awq": - model_kwargs["quantization_config"] = AwqConfig( - **model_config.quantization_config + + # no longer needed per https://github.com/huggingface/transformers/pull/26610 + if "quantization_config" in self.model_kwargs or self.cfg.gptq: + if "load_in_8bit" in self.model_kwargs: + del self.model_kwargs["load_in_8bit"] + if "load_in_4bit" in self.model_kwargs: + del self.model_kwargs["load_in_4bit"] + + def set_attention_config(self) -> None: + """ + sample packing uses custom FA2 patch + """ + if self.cfg.flash_attention: + if not self.cfg.sample_packing and self.cfg.s2_attention: + pass + self.model_kwargs["attn_implementation"] = "flash_attention_2" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "flash_attention_2" ) - elif model_config.quantization_config["quant_method"] == "bitsandbytes": - model_kwargs["quantization_config"] = BitsAndBytesConfig( - **model_config.quantization_config + elif self.cfg.sdp_attention: + self.model_kwargs["attn_implementation"] = "sdpa" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "sdpa" + ) + elif self.cfg.eager_attention: + self.model_kwargs["attn_implementation"] = "eager" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "eager" ) - elif cfg.adapter == "qlora" and cfg.load_in_4bit: - bnb_config = { - "load_in_4bit": True, - "llm_int8_threshold": 6.0, - "llm_int8_has_fp16_weight": False, - "bnb_4bit_compute_dtype": cfg.torch_dtype, - "bnb_4bit_use_double_quant": True, - "bnb_4bit_quant_type": "nf4", - "bnb_4bit_quant_storage": torch.bfloat16, - } - if cfg.model_config_type in ["jamba", "qwen2_moe"] and not ( - cfg.deepspeed or cfg.fsdp - ): - # for some reason, this causes the loss to be off by an order of magnitude - # but deepspeed needs this still in bfloat16 - bnb_config["bnb_4bit_quant_storage"] = torch.float32 - - if cfg.bnb_config_kwargs: - bnb_config.update(cfg.bnb_config_kwargs) - - model_kwargs["quantization_config"] = BitsAndBytesConfig( - **bnb_config, - ) - elif cfg.adapter == "lora" and cfg.load_in_8bit: - bnb_config = { - "load_in_8bit": True, - } - # Exclude mamba blocks from int8 quantization for jamba - if cfg.model_config_type == "jamba": - bnb_config["llm_int8_skip_modules"] = ["mamba"] - model_kwargs["quantization_config"] = BitsAndBytesConfig( - **bnb_config, - ) - - if cfg.load_in_8bit and cfg.adapter is not None: - model_kwargs["load_in_8bit"] = True - if cfg.load_in_4bit and cfg.adapter is not None: - model_kwargs["load_in_4bit"] = True - - # no longer needed per https://github.com/huggingface/transformers/pull/26610 - if "quantization_config" in model_kwargs or cfg.gptq: - if "load_in_8bit" in model_kwargs: - del model_kwargs["load_in_8bit"] - if "load_in_4bit" in model_kwargs: - del model_kwargs["load_in_4bit"] - - # sample packing uses custom FA2 patch - if cfg.flash_attention: - if not cfg.sample_packing and cfg.s2_attention: - pass - model_kwargs["attn_implementation"] = "flash_attention_2" - model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) - elif cfg.sdp_attention: - model_kwargs["attn_implementation"] = "sdpa" - model_config._attn_implementation = "sdpa" # pylint: disable=protected-access - elif cfg.eager_attention: - model_kwargs["attn_implementation"] = "eager" - model_config._attn_implementation = "eager" # pylint: disable=protected-access - - if cfg.low_cpu_mem_usage: - model_kwargs["low_cpu_mem_usage"] = True - qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora" + if self.cfg.low_cpu_mem_usage: + self.model_kwargs["low_cpu_mem_usage"] = True - try: + def build_model(self, qlora_fsdp) -> bool: skip_move_to_device = False if ( # pylint: disable=condition-evals-to-constant) - (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading) + (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading) and not qlora_fsdp and False ): - model = load_sharded_model( - base_model, - model_config, - cfg, - torch_dtype=cfg.torch_dtype, + self.model = load_sharded_model( + self.base_model, + self.model_config, + self.cfg, + torch_dtype=self.cfg.torch_dtype, ) skip_move_to_device = True elif ( qlora_fsdp - and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading - and (cfg.model_config_type == "dbrx" or cfg.qlora_sharded_model_loading) + and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + and ( + self.cfg.model_config_type == "dbrx" + or self.cfg.qlora_sharded_model_loading + ) ): - quant_storage = cfg.torch_dtype + quant_storage = self.cfg.torch_dtype quantization_config = hasattr( - model_config, "quantization_config" - ) and getattr(model_config, "quantization_config") + self.model_config, "quantization_config" + ) and getattr(self.model_config, "quantization_config") quantization_config = ( - quantization_config or model_kwargs["quantization_config"] + quantization_config or self.model_kwargs["quantization_config"] ) - if cfg.is_multimodal: - model_config.text_config = text_model_config - model = load_sharded_model_quant( - base_model, - model_config, - cfg, + if self.cfg.is_multimodal: + self.model_config.text_config = self.text_model_config + self.model = load_sharded_model_quant( + self.base_model, + self.model_config, + self.cfg, quant_storage=quant_storage, quantization_config=quantization_config, ) skip_move_to_device = True elif ( - model_config.model_type == "llama" - and not cfg.trust_remote_code - and not cfg.gptq + self.model_config.model_type == "llama" + and not self.cfg.trust_remote_code + and not self.cfg.gptq ): - if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: skip_move_to_device = True - if "device_map" in model_kwargs: - del model_kwargs["device_map"] - - if cfg.is_multimodal: - model_config.text_config = text_model_config - model = AutoModelLoader.from_pretrained( - base_model, - config=model_config, - **model_kwargs, + if "device_map" in self.model_kwargs: + del self.model_kwargs["device_map"] + + if self.cfg.is_multimodal: + self.model_config.text_config = self.text_model_config + self.model = self.AutoModelLoader.from_pretrained( + self.base_model, + config=self.model_config, + **self.model_kwargs, ) - if cfg.flash_attention and not inference: + # TODO (MengqingCao) split these patches seperately + if self.cfg.flash_attention and not self.inference: from axolotl.monkeypatch.llama_attn_hijack_flash import ( is_xformers_swiglu_available, replace_llama_mlp_with_swiglu, replace_llama_qkv_with_fused, ) - if cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): + if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): LOG.info("patching with SwiGLU") - replace_llama_mlp_with_swiglu(model) + replace_llama_mlp_with_swiglu(self.model) - if cfg.flash_attn_fuse_qkv: + if self.cfg.flash_attn_fuse_qkv: LOG.info("patching with fused QKV") - replace_llama_qkv_with_fused(model) - elif model_type == "MambaLMHeadModel": + replace_llama_qkv_with_fused(self.model) + elif self.model_type == "MambaLMHeadModel": # FIXME this is janky at best and hacked together to make it work MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name - model_kwargs["dtype"] = model_kwargs["torch_dtype"] - model_kwargs["device"] = torch.cuda.current_device() - del model_kwargs["torch_dtype"] - del model_kwargs["device_map"] + self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"] + self.model_kwargs["device"] = torch.cuda.current_device() + del self.model_kwargs["torch_dtype"] + del self.model_kwargs["device_map"] - model = MambaLMHeadModel.from_pretrained( - base_model, - **model_kwargs, + self.model = MambaLMHeadModel.from_pretrained( + self.base_model, + **self.model_kwargs, ) elif ( - model_type - and model_type != "AutoModelForCausalLM" - and not cfg.trust_remote_code + self.model_type + and self.model_type != "AutoModelForCausalLM" + and not self.cfg.trust_remote_code ): - if cfg.gptq: - if cfg.is_multimodal: - model_config.text_config = text_model_config - model = AutoModelLoader.from_pretrained( - base_model, - config=model_config, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, + if self.cfg.is_multimodal: + self.model_config.text_config = self.text_model_config + if self.cfg.gptq: + self.model = self.AutoModelLoader.from_pretrained( + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, ) else: - if cfg.is_multimodal: - model_config.text_config = text_model_config - model = getattr(transformers, model_type).from_pretrained( - base_model, - config=model_config, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, + self.model = getattr(transformers, self.model_type).from_pretrained( + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, ) else: # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this # when training starts if ( - hasattr(text_model_config, "max_seq_len") - and text_model_config.max_seq_len - and cfg.sequence_len > model_config.max_seq_len + hasattr(self.text_model_config, "max_seq_len") + and self.text_model_config.max_seq_len + and self.cfg.sequence_len > self.text_model_config.max_seq_len ): - text_model_config.max_seq_len = cfg.sequence_len - LOG.warning(f"increasing context length to {cfg.sequence_len}") + self.text_model_config.max_seq_len = self.cfg.sequence_len + LOG.warning(f"increasing context length to {self.cfg.sequence_len}") elif ( - hasattr(text_model_config, "max_sequence_length") - and text_model_config.max_sequence_length - and cfg.sequence_len > text_model_config.max_sequence_length + hasattr(self.text_model_config, "max_sequence_length") + and self.text_model_config.max_sequence_length + and self.cfg.sequence_len > self.text_model_config.max_sequence_length ): - text_model_config.max_sequence_length = cfg.sequence_len - LOG.warning(f"increasing context length to {cfg.sequence_len}") - if cfg.gptq: - if cfg.is_multimodal: - model_config.text_config = text_model_config - model = AutoModelLoader.from_pretrained( - base_model, - config=model_config, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, + self.text_model_config.max_sequence_length = self.cfg.sequence_len + LOG.warning(f"increasing context length to {self.cfg.sequence_len}") + if self.cfg.gptq: + if self.cfg.is_multimodal: + self.model_config.text_config = self.text_model_config + self.model = self.AutoModelLoader.from_pretrained( + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, ) else: - if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + if ( + self.cfg.fsdp + and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + ): # disabling either of these two still leads to VRAM spike before setting back down skip_move_to_device = True - if "device_map" in model_kwargs: - del model_kwargs["device_map"] - - if cfg.is_multimodal: - model_config.text_config = text_model_config - model = AutoModelLoader.from_pretrained( - base_model, - config=model_config, - trust_remote_code=cfg.trust_remote_code or False, - **model_kwargs, + if "device_map" in self.model_kwargs: + del self.model_kwargs["device_map"] + + if self.cfg.is_multimodal: + self.model_config.text_config = self.text_model_config + self.model = self.AutoModelLoader.from_pretrained( + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, ) - except Exception as err: # pylint: disable=broad-exception-caught - LOG.exception(err) - raise err - - if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp: - model = model.merge_and_unload() - - embeddings_len = ( - math.ceil(len(tokenizer) / 32) * 32 - if cfg.resize_token_embeddings_to_32x - else len(tokenizer) - ) - if ( - hasattr(model, "get_input_embeddings") - and model.get_input_embeddings().num_embeddings < embeddings_len - ): - model.resize_token_embeddings(embeddings_len) - else: - model.tie_weights() + if is_deepspeed_zero3_enabled(): + skip_move_to_device = True - if ( - hasattr(model, "config") - and hasattr(model.config, "max_position_embeddings") - and model.config.max_position_embeddings - and cfg.sequence_len > model.config.max_position_embeddings - ): - LOG.warning( - f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}" - ) - model.config.max_position_embeddings = cfg.sequence_len + return skip_move_to_device - if ( - hasattr(model, "config") - and hasattr(model.config, "bos_token_id") - and model.config.bos_token_id - and model.config.bos_token_id != tokenizer.bos_token_id - ): - model.config.bos_token_id = tokenizer.bos_token_id + def ajust_model_config(self) -> None: + if ( + hasattr(self.model, "config") + and hasattr(self.model.config, "max_position_embeddings") + and self.model.config.max_position_embeddings + and self.cfg.sequence_len > self.model.config.max_position_embeddings + ): + LOG.warning( + f"increasing model.config.max_position_embeddings from {self.model.config.max_position_embeddings} to {self.cfg.sequence_len}" + ) + self.model.config.max_position_embeddings = self.cfg.sequence_len - if ( - hasattr(model, "config") - and hasattr(model.config, "eos_token_id") - and model.config.eos_token_id - and model.config.eos_token_id != tokenizer.eos_token_id - ): - model.config.eos_token_id = tokenizer.eos_token_id - - if hasattr(model, "device") and model.device.type in ("cuda", "mps"): - log_gpu_memory_usage(LOG, "after model load", model.device) - - # make sure these are fp32 per Ramesh et al. (2021) - embedding_modules = get_linear_embedding_layers(cfg.model_config_type) - if not cfg.fsdp: - # FSDP doesn't like mixed Float and BFloat16 - for name, module in model.named_modules(): - if "norm" in name or name.endswith(".gate"): - module.to(torch.float32) - if model_config.model_type == "btlm": - # don't upcast lm_head for btlm - continue - if any(m in name for m in embedding_modules): - if hasattr(module, "weight"): - module.to(torch.float32) + if ( + hasattr(self.model, "config") + and hasattr(self.model.config, "bos_token_id") + and self.model.config.bos_token_id + and self.model.config.bos_token_id != self.tokenizer.bos_token_id + ): + self.model.config.bos_token_id = self.tokenizer.bos_token_id - needs_fa2_dtype = cfg.adapter or cfg.fsdp - skip_prepare_model_for_kbit_training = False + if ( + hasattr(self.model, "config") + and hasattr(self.model.config, "eos_token_id") + and self.model.config.eos_token_id + and self.model.config.eos_token_id != self.tokenizer.eos_token_id + ): + self.model.config.eos_token_id = self.tokenizer.eos_token_id - if is_deepspeed_zero3_enabled(): + def set_z3_leaf_modules(self) -> None: from deepspeed.utils import ( # pylint: disable=no-name-in-module set_z3_leaf_modules, ) - if cfg.model_config_type in MOE_ARCH_BLOCK: - moe_blocks = MOE_ARCH_BLOCK[cfg.model_config_type] + if self.cfg.model_config_type in MOE_ARCH_BLOCK: + moe_blocks = MOE_ARCH_BLOCK[self.cfg.model_config_type] moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks set_z3_leaf_modules( - model, + self.model, [ - get_module_class_from_name(model, module_name) + get_module_class_from_name(self.model, module_name) for module_name in moe_blocks ], ) - if cfg.model_config_type == "qwen" and cfg.adapter == "lora": - # Qwen doesn't play nicely with LoRA if this is enabled - skip_prepare_model_for_kbit_training = True + def prepare_model(self, qlora_fsdp) -> None: + skip_prepare_model_for_kbit_training = False + if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora": + # Qwen doesn't play nicely with LoRA if this is enabled + skip_prepare_model_for_kbit_training = True - loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits - if cfg.adapter == "lora" and loftq_bits: - skip_prepare_model_for_kbit_training = True + loftq_bits = ( + self.cfg.peft + and self.cfg.peft.loftq_config + and self.cfg.peft.loftq_config.loftq_bits + ) + if self.cfg.adapter == "lora" and loftq_bits: + skip_prepare_model_for_kbit_training = True - if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading): - # make sure everything is in the same dtype - skip_prepare_model_for_kbit_training = True + if qlora_fsdp or ( + self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + ): + # make sure everything is in the same dtype + skip_prepare_model_for_kbit_training = True - if is_deepspeed_zero3_enabled(): - skip_prepare_model_for_kbit_training = True + if is_deepspeed_zero3_enabled(): + skip_prepare_model_for_kbit_training = True + + is_load_in_8bit = ( + "load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"] + ) + is_load_in_4bit = ( + "load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"] + ) - if cfg.adapter in ["lora", "qlora"]: - if cfg.gradient_checkpointing: - model.gradient_checkpointing_enable( - gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs - ) if ( - cfg.load_in_8bit or cfg.load_in_4bit - ) and not skip_prepare_model_for_kbit_training: + not skip_prepare_model_for_kbit_training + and self.cfg.adapter in ["lora", "qlora"] + and (is_load_in_8bit or is_load_in_4bit) + ): LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") - model = prepare_model_for_kbit_training( - model, use_gradient_checkpointing=cfg.gradient_checkpointing + self.model = prepare_model_for_kbit_training( + self.model, use_gradient_checkpointing=self.cfg.gradient_checkpointing ) - needs_fa2_dtype = True - # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to - # convert them back to fp16/bf16 for flash-attn compatibility. - if (needs_fa2_dtype or cfg.flash_attention) and not qlora_fsdp: - LOG.info("converting modules to %s for flash attention", cfg.torch_dtype) - for name, module in model.named_modules(): + def convert_embedding_modules_dtype( + self, embedding_modules, dist_dtype, before_kbit_train_or_finetune + ) -> None: + for name, module in self.model.named_modules(): if "norm" in name: - module.to(cfg.torch_dtype) + module.to(dist_dtype) + if before_kbit_train_or_finetune: + if name.endswith(".gate"): + module.to(dist_dtype) + if self.model_config.model_type == "btlm": + # don't upcast lm_head for btlm + continue if any(m in name for m in embedding_modules): if hasattr(module, "weight"): - module.to(cfg.torch_dtype) - - lora_config = None - if not reference_model or cfg.lora_model_dir: - # if we're not loading the reference model, then we're loading the model for training - # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config - if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto"] and not cfg.merge_lora: - _, lora_config = load_lora(model, cfg, inference=False, config_only=True) + module.to(dist_dtype) + + def apply_lora_patch(self) -> None: + if self.cfg.unsloth_lora_mlp: + from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch + + integrate_lora_mlp_patch(self.model) + if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: + from axolotl.monkeypatch.unsloth_ import integrate_lora_patch + + integrate_lora_patch(self.model, self.cfg) + if self.cfg.unsloth_rope: + from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings + + integrate_rope_embeddings() + + def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: + self.apply_patches() + self.set_auto_model_loader() + self.set_device_map_config() + if self.cfg.revision_of_model: + self.model_kwargs["revision"] = self.cfg.revision_of_model + self.set_quantization_config() + self.set_attention_config() + + qlora_fsdp = self.cfg.fsdp and self.cfg.adapter == "qlora" + skip_move_to_device = False + + try: + skip_move_to_device = self.build_model(qlora_fsdp) + except Exception as err: # pylint: disable=broad-exception-caught + LOG.exception(err) + raise err + + if isinstance(self.model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp: + self.model = self.model.merge_and_unload() + + embeddings_len = ( + math.ceil(len(self.tokenizer) / 32) * 32 + if self.cfg.resize_token_embeddings_to_32x + else len(self.tokenizer) + ) + if ( + hasattr(self.model, "get_input_embeddings") + and self.model.get_input_embeddings().num_embeddings < embeddings_len + ): + self.model.resize_token_embeddings(embeddings_len) else: - model, lora_config = load_adapter(model, cfg, cfg.adapter) + self.model.tie_weights() + + self.ajust_model_config() + + # log device memory usage + if hasattr(self.model, "device") and self.model.device.type in ("cuda", "mps"): + log_gpu_memory_usage(LOG, "after model load", self.model.device) + + # make sure these are fp32 per Ramesh et al. (2021) + embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type) + if not self.cfg.fsdp: + # FSDP doesn't like mixed Float and BFloat16 + self.convert_embedding_modules_dtype( + embedding_modules, + dist_dtype=torch.float32, + before_kbit_train_or_finetune=True, + ) - if is_deepspeed_zero3_enabled(): - skip_move_to_device = True + if is_deepspeed_zero3_enabled(): + self.set_z3_leaf_modules() - if ( - cfg.ddp - and not load_in_8bit - and not (cfg.rl and cfg.load_in_4bit) - and not skip_move_to_device - ): - # TODO revaldate this conditional - model.to(f"cuda:{cfg.local_rank}") + needs_fa2_dtype = self.cfg.adapter or self.cfg.fsdp + if self.cfg.adapter in ["lora", "qlora"]: + needs_fa2_dtype = True + if self.cfg.gradient_checkpointing: + self.model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs + ) + + self.prepare_model(qlora_fsdp) - if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: - setattr(model, "is_parallelizable", True) - setattr(model, "model_parallel", True) + # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to + # convert them back to fp16/bf16 for flash-attn compatibility. + if (needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp: + LOG.info( + "converting modules to %s for flash attention", self.cfg.torch_dtype + ) + self.convert_embedding_modules_dtype( + embedding_modules, + dist_dtype=self.cfg.torch_dtype, + before_kbit_train_or_finetune=False, + ) + + # --------------------------------------------------------- + # load lora or adapter + # --------------------------------------------------------- + lora_config = None + if not self.reference_model or self.cfg.lora_model_dir: + # if we're not loading the reference model, then we're loading the model for training + # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config + if ( + self.cfg.adapter + and self.cfg.rl in ["dpo", "ipo", "kto"] + and not self.cfg.merge_lora + ): + _, lora_config = load_lora( + self.model, self.cfg, inference=False, config_only=True + ) + else: + self.model, lora_config = load_adapter( + self.model, self.cfg, self.cfg.adapter + ) - requires_grad = [] - for name, param in model.named_parameters(recurse=True): - if param.requires_grad: - requires_grad.append(f"{name}: {param.requires_grad}") - if len(requires_grad) == 0: - LOG.warning("there are no parameters that require gradient updates") - if hasattr(model, "config"): - model.config.use_cache = False + # --------------------------------------------------------- + # put model to accelerator + # --------------------------------------------------------- + is_load_in_8bit = ( + "load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"] + ) + is_load_in_4bit = ( + "load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"] + ) + if ( + self.cfg.ddp + and not is_load_in_8bit + and not (self.cfg.rl and is_load_in_4bit) + and not skip_move_to_device + ): + # TODO revaldate this conditional + self.model.to(f"cuda:{self.cfg.local_rank}") - if cfg.flash_optimum: - from optimum.bettertransformer import BetterTransformer + if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: + setattr(self.model, "is_parallelizable", True) + setattr(self.model, "model_parallel", True) - model = BetterTransformer.transform(model) + # --------------------------------------------------------- + # parameters that require gradient updates + # --------------------------------------------------------- + requires_grad = [] + for name, param in self.model.named_parameters(recurse=True): + if param.requires_grad: + requires_grad.append(f"{name}: {param.requires_grad}") + if len(requires_grad) == 0: + LOG.warning("there are no parameters that require gradient updates") + if hasattr(self.model, "config"): + self.model.config.use_cache = False - if cfg.adapter is not None: - log_gpu_memory_usage(LOG, "after adapters", model.device) + if self.cfg.flash_optimum: + from optimum.bettertransformer import BetterTransformer - if cfg.unsloth_lora_mlp: - from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch + self.model = BetterTransformer.transform(self.model) - integrate_lora_mlp_patch(model) - if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: - from axolotl.monkeypatch.unsloth_ import integrate_lora_patch + if self.cfg.adapter is not None: + log_gpu_memory_usage(LOG, "after adapters", self.model.device) - integrate_lora_patch(model, cfg) + self.apply_lora_patch() - if cfg.unsloth_rope: - from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() - integrate_rope_embeddings() + # TODO resume_from_checkpoint handling + return self.model, lora_config - for _ in range(3): - gc.collect() - torch.cuda.empty_cache() - # TODO resume_from_checkpoint handling - return model, lora_config +def load_model( + cfg: DictDefault, + tokenizer: PreTrainedTokenizerBase, + *, + processor: ProcessorMixin = None, # pylint: disable=unused-argument + inference: bool = False, + reference_model: bool = False, + **kwargs, # pylint: disable=unused-argument +) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: + """ + Load a model for a given configuration and tokenizer. + """ + loader = ModelLoader( + cfg, + tokenizer, + processor=processor, + inference=inference, + reference_model=reference_model, + **kwargs, + ) + return loader.load_model() def load_adapter(model, cfg, adapter, inference=False): diff --git a/tests/e2e/test_load_model.py b/tests/e2e/test_load_model.py new file mode 100644 index 0000000000..31a9b1a878 --- /dev/null +++ b/tests/e2e/test_load_model.py @@ -0,0 +1,95 @@ +"""Module for testing ModelLoader.""" + +import shutil +import tempfile + +import pytest +import torch + +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import ModelLoader, load_model, load_tokenizer + + +@pytest.fixture(name="temp_dir") +def fixture_temp_dir(): + temp_dir = tempfile.mkdtemp() + yield temp_dir + shutil.rmtree(temp_dir) + + +class TestLoadModelUtils: + """ + Testing module testing ModelLoader. + """ + + def setup_method(self): + # load config + self.cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "tokenizer_config": "JackFram/llama-68m", + "sequence_len": 1024, + "load_in_8bit": False, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.1, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + } + ) + self.model_loader = ( # pylint: disable=attribute-defined-outside-init + ModelLoader( + cfg=self.cfg, + tokenizer="", + ) + ) + + @pytest.mark.parametrize("embedding_modules", ["embed_tokens", "lm_head"]) + @pytest.mark.parametrize( + "dist_dtype", [torch.bfloat16, torch.float16, torch.float32] + ) + @pytest.mark.parametrize("before_kbit_train_or_finetune", [True, False]) + def test_convert_embedding_modules_dtype( + self, temp_dir, embedding_modules, dist_dtype, before_kbit_train_or_finetune + ): + self.cfg.output_dir = temp_dir + self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all + self.model_loader.model, _ = load_model( + self.cfg, + self.model_loader.tokenizer, + inference=False, + reference_model=True, + ) + self.model_loader.convert_embedding_modules_dtype( + embedding_modules, dist_dtype, before_kbit_train_or_finetune + ) + for name, module in self.model_loader.model.named_modules(): + if ( + "norm" in name + or (before_kbit_train_or_finetune and name.endswith(".gate")) + or ( + any(m in name for m in embedding_modules) + and hasattr(module, "weight") + ) + ): + for _, param in module.named_parameters(): + assert param.dtype == dist_dtype diff --git a/tests/utils/test_models.py b/tests/utils/test_models.py index e06bb6c250..31698f05fb 100644 --- a/tests/utils/test_models.py +++ b/tests/utils/test_models.py @@ -1,18 +1,64 @@ """Module for testing models utils file.""" - -import unittest -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest +from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase +from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +from transformers.utils.import_utils import is_torch_mps_available from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_model +from axolotl.utils.models import ModelLoader, load_model -class ModelsUtilsTest(unittest.TestCase): +class TestModelsUtils: """Testing module for models utils.""" + def setup_method(self) -> None: + # load config + self.cfg = DictDefault( # pylint: disable=attribute-defined-outside-init + { + "base_model": "JackFram/llama-68m", + "model_type": "LlamaForCausalLM", + "tokenizer_type": "LlamaTokenizer", + "load_in_8bit": True, + "load_in_4bit": False, + "adapter": "lora", + "flash_attention": False, + "sample_packing": True, + "device_map": "auto", + } + ) + self.tokenizer = MagicMock( # pylint: disable=attribute-defined-outside-init + spec=PreTrainedTokenizerBase + ) + self.inference = False # pylint: disable=attribute-defined-outside-init + self.reference_model = True # pylint: disable=attribute-defined-outside-init + + # init ModelLoader + self.model_loader = ( # pylint: disable=attribute-defined-outside-init + ModelLoader( + cfg=self.cfg, + tokenizer=self.tokenizer, + inference=self.inference, + reference_model=self.reference_model, + ) + ) + + def test_set_device_map_config(self): + # check device_map + device_map = self.cfg.device_map + if is_torch_mps_available(): + device_map = "mps" + self.model_loader.set_device_map_config() + if is_deepspeed_zero3_enabled(): + assert "device_map" not in self.model_loader.model_kwargs + else: + assert device_map in self.model_loader.model_kwargs["device_map"] + + # check torch_dtype + assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"] + def test_cfg_throws_error_with_s2_attention_and_sample_packing(self): cfg = DictDefault( { @@ -35,3 +81,38 @@ def test_cfg_throws_error_with_s2_attention_and_sample_packing(self): "shifted-sparse attention does not currently support sample packing" in str(exc.value) ) + + @pytest.mark.parametrize("adapter", ["lora", "qlora", None]) + @pytest.mark.parametrize("load_in_8bit", [True, False]) + @pytest.mark.parametrize("load_in_4bit", [True, False]) + @pytest.mark.parametrize("gptq", [True, False]) + def test_set_quantization_config( + self, + adapter, + load_in_8bit, + load_in_4bit, + gptq, + ): + # init cfg as args + self.cfg.load_in_8bit = load_in_8bit + self.cfg.load_in_4bit = load_in_4bit + self.cfg.gptq = gptq + self.cfg.adapter = adapter + + self.model_loader.set_quantization_config() + if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq: + assert not ( + hasattr(self.model_loader.model_kwargs, "load_in_8bit") + and hasattr(self.model_loader.model_kwargs, "load_in_4bit") + ) + elif load_in_8bit and self.cfg.adapter is not None: + assert self.model_loader.model_kwargs["load_in_8bit"] + elif load_in_4bit and self.cfg.adapter is not None: + assert self.model_loader.model_kwargs["load_in_4bit"] + + if (self.cfg.adapter == "qlora" and load_in_4bit) or ( + self.cfg.adapter == "lora" and load_in_8bit + ): + assert self.model_loader.model_kwargs.get( + "quantization_config", BitsAndBytesConfig + ) From 2501c1a6a3392b658fcd5d5ace3d5fb71b633afa Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 25 Oct 2024 22:28:23 +0700 Subject: [PATCH 84/89] Fix: Gradient Accumulation issue (#1980) * feat: support new arg num_items_in_batch * use kwargs to manage extra unknown kwargs for now * upgrade against upstream transformers main * make sure trl is on latest too * fix for upgraded trl * fix: handle trl and transformer signature change * feat: update trl to handle transformer signature * RewardDataCollatorWithPadding no longer has max_length * handle updated signature for tokenizer vs processor class * invert logic for tokenizer vs processor class * processing_class, not processor class * also handle processing class in dpo * handle model name w model card creation * upgrade transformers and add a loss check test * fix install of tbparse requirements * make sure to add tbparse to req * feat: revert kwarg to positional kwarg to be explicit --------- Co-authored-by: Wing Lian --- .github/workflows/pypi.yml | 2 +- .github/workflows/tests-nightly.yml | 3 +- .github/workflows/tests.yml | 2 +- cicd/Dockerfile.jinja | 3 +- requirements-dev.txt | 1 + requirements.txt | 4 +- src/axolotl/core/trainer_builder.py | 72 +++++++++++++--- src/axolotl/monkeypatch/unsloth_.py | 85 +++++-------------- src/axolotl/train.py | 6 +- tests/e2e/patched/test_unsloth_integration.py | 12 +-- tests/e2e/test_packing_loss.py | 74 ++++++++++++++++ 11 files changed, 168 insertions(+), 96 deletions(-) create mode 100644 tests/e2e/test_packing_loss.py diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 885239d185..04dbc6385c 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -27,7 +27,7 @@ jobs: run: | pip3 install wheel packaging pip3 install -e . - pip3 install -r requirements-tests.txt + pip3 install -r requirements-dev.txt -r requirements-tests.txt - name: Extract tag name id: tag diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml index 56eaae2398..90b1e23cd2 100644 --- a/.github/workflows/tests-nightly.yml +++ b/.github/workflows/tests-nightly.yml @@ -47,13 +47,14 @@ jobs: sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt + sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt - name: Install dependencies run: | pip3 install --upgrade pip pip3 install --upgrade packaging pip3 install -U -e . - pip3 install -r requirements-tests.txt + pip3 install -r requirements-dev.txt -r requirements-tests.txt - name: Run tests run: | diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 130ac6e7b6..ba50adfd35 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -62,7 +62,7 @@ jobs: run: | pip3 show torch pip3 install -U -e . - pip3 install -r requirements-tests.txt + pip3 install -r requirements-dev.txt -r requirements-tests.txt - name: Run tests run: | diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja index 3b082a15b0..8ce6550056 100644 --- a/cicd/Dockerfile.jinja +++ b/cicd/Dockerfile.jinja @@ -27,6 +27,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \ sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \ sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \ sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \ + sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \ fi RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ @@ -36,7 +37,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ fi # So we can test the Docker image -RUN pip install -r requirements-tests.txt +RUN pip install -r requirements-dev.txt -r requirements-tests.txt # fix so that git fetch/pull from remote works RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \ diff --git a/requirements-dev.txt b/requirements-dev.txt index 4b5df167b6..dcc729d1b2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,3 +2,4 @@ pre-commit black mypy types-requests +tbparse diff --git a/requirements.txt b/requirements.txt index 067be05cf2..b6e9a554e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 peft==0.13.2 -transformers==4.45.2 +transformers==4.46.0 tokenizers>=0.20.1 bitsandbytes==0.44.1 accelerate==1.0.1 @@ -43,7 +43,7 @@ s3fs>=2024.5.0 gcsfs>=2024.5.0 # adlfs -trl==0.9.6 +trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924 zstandard==0.22.0 fastcore diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index f05efe7b82..319ea7be59 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -7,6 +7,7 @@ import gc import importlib import importlib.util +import inspect import logging import math import os @@ -27,7 +28,6 @@ from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from transformers import ( EarlyStoppingCallback, - PreTrainedModel, Trainer, TrainerCallback, TrainingArguments, @@ -666,7 +666,9 @@ def get_bench_dataloader( return DataLoader(bench_dataset, **dataloader_params) # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params)) - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss( + self, model, inputs, return_outputs=False, num_items_in_batch=None + ): # use one's weighted cross entropy loss calc # if self.args.sample_packing: # labels = inputs.pop("labels") @@ -674,8 +676,18 @@ def compute_loss(self, model, inputs, return_outputs=False): # loss = trainer_weighted_loss(outputs, labels, shift_labels=True) # return (loss, outputs) if return_outputs else loss if self.args.orpo_alpha: - return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs) - return super().compute_loss(model, inputs, return_outputs=return_outputs) + return self.orpo_compute_loss( + model, + inputs, + return_outputs=return_outputs, + num_items_in_batch=num_items_in_batch, + ) + return super().compute_loss( + model, + inputs, + return_outputs=return_outputs, + num_items_in_batch=num_items_in_batch, + ) @staticmethod def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None): @@ -771,7 +783,13 @@ def orpo_compute_logps( ).squeeze(2) return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1) - def orpo_compute_loss(self, model, inputs, return_outputs=False): + def orpo_compute_loss( + self, + model, + inputs, + return_outputs=False, + num_items_in_batch=None, # pylint: disable=unused-argument + ): concat_inputs = AxolotlTrainer.orpo_concatenate_inputs( inputs, label_pad_token=-100, @@ -898,6 +916,7 @@ def compute_loss( model, inputs, return_outputs=False, # pylint: disable=unused-argument + num_items_in_batch=None, # pylint: disable=unused-argument ): input_ids = inputs.pop("input_ids") lm_logits = model(input_ids).logits @@ -1005,18 +1024,32 @@ def push_to_hub(self, *args, **kwargs) -> str: return super().push_to_hub(*args, **kwargs) def tokenize_row( - self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None + self, + features, + processing_class, + max_prompt_length, + max_completion_length, + add_special_tokens, ) -> Dict: - res = super().tokenize_row(feature, model=model) - if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None: + res = super().tokenize_row( + features, + processing_class, + max_prompt_length, + max_completion_length, + add_special_tokens, + ) + if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None: for key in res.keys(): res[key] = res[key][1:] return res def training_step( - self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]] + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + num_items_in_batch=None, ) -> torch.Tensor: - loss: torch.Tensor = super().training_step(model, inputs) + loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch) gc.collect() torch.cuda.empty_cache() return loss @@ -1667,12 +1700,17 @@ def build(self, total_num_steps): return_tensors="pt", **data_collator_kwargs, ) + sig = inspect.signature(trainer_cls) + if "processing_class" in sig.parameters.keys(): + trainer_kwargs["processing_class"] = self.tokenizer + else: + trainer_kwargs["tokenizer"] = self.tokenizer + trainer = trainer_cls( model=self.model, train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, args=training_args, - tokenizer=self.tokenizer, data_collator=self.build_collator(training_args, **data_collator_kwargs), callbacks=self.get_callbacks(), **trainer_kwargs, @@ -1713,6 +1751,8 @@ def build_collator( ] if self.cfg.reward_model: collator = RewardDataCollatorWithPadding + if "max_length" in kwargs: + kwargs.pop("max_length") elif use_batch_sampler_collator: if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES: collator = V2BatchSamplerDataCollatorForSeq2Seq @@ -1915,7 +1955,7 @@ def build(self, total_num_steps): dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len dpo_trainer_kwargs["max_target_length"] = None dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len - dpo_trainer_kwargs["generate_during_eval"] = True + dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb elif self.cfg.rl == "orpo": trainer_cls = AxolotlORPOTrainer trainer_cls_args = [self.model] @@ -1927,11 +1967,17 @@ def build(self, total_num_steps): trainer_cls_args = [self.model] else: raise ValueError(f"Unsupported RL: {self.cfg.rl}") + + sig = inspect.signature(trainer_cls) + if "processing_class" in sig.parameters.keys(): + dpo_trainer_kwargs["processing_class"] = self.tokenizer + else: + dpo_trainer_kwargs["tokenizer"] = self.tokenizer + dpo_trainer = trainer_cls( *trainer_cls_args, args=training_args, train_dataset=self.train_dataset, - tokenizer=self.tokenizer, callbacks=self.get_callbacks(), **dpo_trainer_kwargs, ) diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py index 3d42ad17f1..c8272ac735 100644 --- a/src/axolotl/monkeypatch/unsloth_.py +++ b/src/axolotl/monkeypatch/unsloth_.py @@ -16,26 +16,6 @@ LOG = get_logger("axolotl.monkeypatch.unsloth") -ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) -""" - -PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - loss = fast_cross_entropy_loss( - logits = shift_logits, - labels = shift_labels, - ) -""" - ORIGINAL_QKV_CODE = """ query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -80,12 +60,6 @@ def get_forward_code() -> str: return forward -def check_cel_is_patchable() -> bool: - forward = get_forward_code() - forward, _ = detab_code(forward) - return ORIGINAL_CEL_CODE in forward - - def get_self_attn_code() -> str: forward = inspect.getsource(LlamaFlashAttention2.forward) return forward @@ -98,48 +72,31 @@ def check_self_attn_is_patchable() -> bool: def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None: - if model_type == "llama": - forward = get_forward_code() - LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access - forward, _ = detab_code(forward) - assert ORIGINAL_CEL_CODE in forward, "Original forward code not found" + from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss + + def UnslothForCausalLMLoss( # pylint: disable=invalid-name + logits, + labels, + vocab_size: int, # pylint: disable=unused-argument + num_items_in_batch: int = None, + ignore_index: int = -100, # pylint: disable=unused-argument + **kwargs, # pylint: disable=unused-argument + ): + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() - forward = forward.replace( - "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", "" - ) - forward = forward.replace( - "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)", - "", - ) - forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE) - forward = forward.replace( - "def forward(", - "def fast_cross_entropy_loss_forward(", - 1, + loss = fast_cross_entropy_loss( + logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch ) + return loss - # load imports necessary - import transformers.models.llama.modeling_llama - - items_to_import = [] - for item in dir(transformers.models.llama.modeling_llama): - if item in forward: - items_to_import.append(item) - - exec( # pylint: disable=exec-used # nosec B102 - "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss", - globals(), - ) + if model_type == "llama": + from transformers.loss import loss_utils - exec( # pylint: disable=exec-used # nosec B102 - "from transformers.models.llama.modeling_llama import (" - + ", ".join(x for x in items_to_import) - + ")", - globals(), - ) - exec(forward, globals()) # pylint: disable=exec-used # nosec B102 - LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True) - LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821 + loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss # type: ignore[assignment] else: raise ValueError("Unsupported model type") diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 4ce28d8a31..5fde4d3848 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -260,8 +260,10 @@ def terminate_handler(_, __, model_weakref): if not cfg.hub_model_id: try: - trainer.create_model_card(model_name=cfg.output_dir.lstrip("./")) - except AttributeError: + trainer.create_model_card( + model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8") + ) + except (AttributeError, UnicodeDecodeError): pass elif cfg.hub_model_id: # defensively push to the hub to ensure the model card is updated diff --git a/tests/e2e/patched/test_unsloth_integration.py b/tests/e2e/patched/test_unsloth_integration.py index 39c7abb1c1..8882742861 100644 --- a/tests/e2e/patched/test_unsloth_integration.py +++ b/tests/e2e/patched/test_unsloth_integration.py @@ -1,22 +1,12 @@ """Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected.""" import unittest -from axolotl.monkeypatch.unsloth_ import ( - check_cel_is_patchable, - check_self_attn_is_patchable, -) +from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable class TestUnslothIntegration(unittest.TestCase): """Unsloth monkeypatch integration tests.""" - def test_is_cel_patchable(self): - # ensures the current version of transformers has loss code that matches our patching code - self.assertTrue( - check_cel_is_patchable(), - "HF transformers loss code has changed and isn't patchable", - ) - def test_is_self_attn_patchable(self): # ensures the current version of transformers has loss code that matches our patching code self.assertTrue( diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py new file mode 100644 index 0000000000..73f9e60bac --- /dev/null +++ b/tests/e2e/test_packing_loss.py @@ -0,0 +1,74 @@ +""" +E2E tests for packed training +""" + +import logging +import os +import unittest + +from tbparse import SummaryReader +from transformers.utils import is_torch_bf16_gpu_available + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from .utils import most_recent_subdir, with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestPackedLlama(unittest.TestCase): + """ + Test case for Packed training of llama models + """ + + @with_temp_dir + def test_loss_packed(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM-135M", + "sequence_len": 1024, + "sample_packing": True, + "flash_attention": True, + "val_set_size": 0.0, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "vicgalle/alpaca-gpt4", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 2, + "gradient_accumulation_steps": 4, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 5, + "use_tensorboard": True, + } + ) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + + tb_log_path = most_recent_subdir(temp_dir + "/runs") + event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0]) + reader = SummaryReader(event_file) + df = reader.scalars # pylint: disable=invalid-name + df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name + assert df.value.values[-1] < 2.0, "Loss is too high" From d3c45d27b54d44f354c0e4b2a7d5c515dd6be414 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 28 Oct 2024 07:32:49 -0400 Subject: [PATCH 85/89] fix zero3 (#1994) --- deepspeed_configs/zero3_bf16.json | 9 ---- .../zero3_bf16_cpuoffload_all.json | 9 ---- .../zero3_bf16_cpuoffload_params.json | 9 ---- requirements.txt | 2 +- src/axolotl/utils/models.py | 41 ++++++++++++++++++- 5 files changed, 41 insertions(+), 29 deletions(-) diff --git a/deepspeed_configs/zero3_bf16.json b/deepspeed_configs/zero3_bf16.json index 16e64d76b4..49fb757552 100644 --- a/deepspeed_configs/zero3_bf16.json +++ b/deepspeed_configs/zero3_bf16.json @@ -14,15 +14,6 @@ "bf16": { "enabled": true }, - "fp16": { - "enabled": "auto", - "auto_cast": false, - "loss_scale": 0, - "initial_scale_power": 32, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, "gradient_accumulation_steps": "auto", "gradient_clipping": "auto", "train_batch_size": "auto", diff --git a/deepspeed_configs/zero3_bf16_cpuoffload_all.json b/deepspeed_configs/zero3_bf16_cpuoffload_all.json index 09ca6785b2..3ccc66db48 100644 --- a/deepspeed_configs/zero3_bf16_cpuoffload_all.json +++ b/deepspeed_configs/zero3_bf16_cpuoffload_all.json @@ -24,15 +24,6 @@ "bf16": { "enabled": true }, - "fp16": { - "enabled": "auto", - "auto_cast": false, - "loss_scale": 0, - "initial_scale_power": 32, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, "gradient_accumulation_steps": "auto", "gradient_clipping": "auto", "train_batch_size": "auto", diff --git a/deepspeed_configs/zero3_bf16_cpuoffload_params.json b/deepspeed_configs/zero3_bf16_cpuoffload_params.json index 41d4a21323..fe21d35f88 100644 --- a/deepspeed_configs/zero3_bf16_cpuoffload_params.json +++ b/deepspeed_configs/zero3_bf16_cpuoffload_params.json @@ -20,15 +20,6 @@ "bf16": { "enabled": true }, - "fp16": { - "enabled": "auto", - "auto_cast": false, - "loss_scale": 0, - "initial_scale_power": 32, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, "gradient_accumulation_steps": "auto", "gradient_clipping": "auto", "train_batch_size": "auto", diff --git a/requirements.txt b/requirements.txt index b6e9a554e5..6bb1aa6848 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ tokenizers>=0.20.1 bitsandbytes==0.44.1 accelerate==1.0.1 datasets==3.0.1 -deepspeed==0.14.4 +deepspeed==0.15.3 pydantic==2.6.3 addict fire diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 5e53df72cb..8b433c366b 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -40,7 +40,10 @@ PreTrainedTokenizerBase, ProcessorMixin, ) -from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +from transformers.integrations.deepspeed import ( + HfTrainerDeepSpeedConfig, + is_deepspeed_zero3_enabled, +) from axolotl.common.architectures import MOE_ARCH_BLOCK from axolotl.models.mamba import fix_mamba_attn_for_loss @@ -705,6 +708,38 @@ def set_attention_config(self) -> None: self.model_kwargs["low_cpu_mem_usage"] = True def build_model(self, qlora_fsdp) -> bool: + def _configure_zero3_memory_efficient_loading(): + """ + Set the deepspeed config to load the model into RAM first before moving to VRAM. + + We need to return hf_ds_cfg as it needs to exist before model loading. + """ + hf_ds_cfg = None + + if os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3": + hf_ds_cfg = HfTrainerDeepSpeedConfig(self.cfg.deepspeed) + hf_ds_cfg.fill_match( + "train_micro_batch_size_per_gpu", self.cfg.micro_batch_size + ) + hf_ds_cfg.fill_match( + "gradient_accumulation_steps", self.cfg.gradient_accumulation_steps + ) + hf_ds_cfg.fill_match( + "train_batch_size", + int(os.getenv("WORLD_SIZE", "1")) + * self.cfg.micro_batch_size + * self.cfg.gradient_accumulation_steps, + ) + if "device_map" in self.model_kwargs: + del self.model_kwargs["device_map"] + + transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True + transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = ( + lambda: True + ) + + return hf_ds_cfg + skip_move_to_device = False if ( # pylint: disable=condition-evals-to-constant) (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading) @@ -753,6 +788,8 @@ def build_model(self, qlora_fsdp) -> bool: if "device_map" in self.model_kwargs: del self.model_kwargs["device_map"] + _ = _configure_zero3_memory_efficient_loading() + if self.cfg.is_multimodal: self.model_config.text_config = self.text_model_config self.model = self.AutoModelLoader.from_pretrained( @@ -846,6 +883,8 @@ def build_model(self, qlora_fsdp) -> bool: if "device_map" in self.model_kwargs: del self.model_kwargs["device_map"] + _ = _configure_zero3_memory_efficient_loading() + if self.cfg.is_multimodal: self.model_config.text_config = self.text_model_config self.model = self.AutoModelLoader.from_pretrained( From e1e0556c9951ef53ee627310bf3d248908fdf39a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 28 Oct 2024 17:02:04 -0400 Subject: [PATCH 86/89] add option for resizing embeddings when adding new tokens (#2000) * add option for resizing embeddings when adding new tokens * let's just be opinonated about this setting and set it to False --- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 1 + src/axolotl/utils/models.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 4831da3c8a..16cf312ce1 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -549,6 +549,7 @@ class Config: resume_from_checkpoint: Optional[str] = None auto_resume_from_checkpoints: Optional[bool] = None resize_token_embeddings_to_32x: Optional[bool] = None + mean_resizing_embeddings: Optional[bool] = False rl: Optional[RLType] = None reward_model: Optional[bool] = None diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 8b433c366b..97844a5bf3 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1042,7 +1042,10 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: hasattr(self.model, "get_input_embeddings") and self.model.get_input_embeddings().num_embeddings < embeddings_len ): - self.model.resize_token_embeddings(embeddings_len) + resize_kwargs = {} + if self.cfg.mean_resizing_embeddings is not None: + resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings + self.model.resize_token_embeddings(embeddings_len, **resize_kwargs) else: self.model.tie_weights() From bfc77b0f3628c8df43f974873344124b8c947c26 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 29 Oct 2024 10:14:51 +0700 Subject: [PATCH 87/89] =?UTF-8?q?Feat:=20Add=20support=20for=20tokenizer?= =?UTF-8?q?=E2=80=99s=20or=20custom=20jinja=20chat=5Ftemplate=20(#1970)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Allow using tokenizer's default chat template with fallbacks Summary of changes: 1. Adds `tokenizer_default` as option for `chat_template` in `chat_template` prompt strategy that allows using the chat template from tokenizer's config.json 2. Allows falling back to chat templates available in axolotl if tokenizer does not have a chat template 3. Adds a mistral chat template which supports system message - taken from https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja --- Why? Many popular models are not trained with chatml format. As a result for the model to correctly learn chatml we have to turn on train_on_inputs which requires more compute and time. If we can use the model's already learned chat template we can just learn the output tokens --- Todo: - Write tests * Add tests * Fix lint and bug post merge from main * Add option `chat_template_jinja` to provide a jinja template * remove custom mistral template * Address review comments and add docs * Update docs/dataset-formats/conversation.qmd Co-authored-by: NanoCode012 * fix: set default to tokenizer template * Merge branch 'main' into cj_tokenizer_default_prompt_template * chore: remove redundant function * fix: re-arrange enum declaration position * fix: refactor artifact left from main merge * feat(doc): updated config with chat template options and clarified examples * chore: clarify doc * chore: added example for non-default template * chore: refactor * fix: test * fix: config being dropped and unittest to catch that * chore: lint * chore: skip duplicate * fix: rename var after merge * feat: add test for levy's dpo case * fix: remove default setting on edge case where chat template overriden in dataset section * feat: handle sharegpt deprecation better in docs * feat: add example using fallback * feat: handles chat_template requiring specific user/assistant order * fix: update test based on new defaults * fix: imported name incorrectly updated on merge * chore: lint * fix: update dummy message to prevent potential overlap with real content * fix(doc): formatting * fix: update bradleyterry to use new chat_template --------- Co-authored-by: Chirag Jain --- README.md | 2 +- docs/config.qmd | 57 ++++- docs/dataset-formats/conversation.qmd | 137 ++++++++++ src/axolotl/cli/__init__.py | 4 +- src/axolotl/core/trainer_builder.py | 4 +- .../bradley_terry/__init__.py | 2 +- .../bradley_terry/chat_template.py | 42 ++-- .../prompt_strategies/chat_template.py | 8 +- .../prompt_strategies/dpo/chat_template.py | 24 +- .../prompt_strategies/orpo/chat_template.py | 29 +-- src/axolotl/prompt_strategies/sharegpt.py | 2 +- src/axolotl/utils/chat_templates.py | 89 ++++++- src/axolotl/utils/config/__init__.py | 1 + .../config/models/input/v0_4_1/__init__.py | 129 +++++++--- src/axolotl/utils/models.py | 7 +- .../test_chat_template_utils.py | 125 +++++++++ .../prompt_strategies/test_chat_templates.py | 14 +- .../test_chat_templates_advanced.py | 26 +- .../test_dpo_chat_templates.py | 78 +++++- tests/test_validation_dataset.py | 238 ++++++++++++++++++ 20 files changed, 900 insertions(+), 118 deletions(-) create mode 100644 tests/prompt_strategies/test_chat_template_utils.py create mode 100644 tests/test_validation_dataset.py diff --git a/README.md b/README.md index 4ce7a351bb..21b954a56c 100644 --- a/README.md +++ b/README.md @@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod - typescript type: ... # unimplemented custom format - # fastchat conversation (deprecation soon, use chat_template) + # fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template) # See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py - path: ... type: sharegpt diff --git a/docs/config.qmd b/docs/config.qmd index 703d587753..a7bf9080bf 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -83,7 +83,7 @@ lora_on_cpu: true datasets: # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files - path: vicgalle/alpaca-gpt4 - # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection] + # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection] type: alpaca # format | format: (chat/instruct) | .load_ ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file data_files: # Optional[str] path to source data files @@ -124,6 +124,48 @@ datasets: # For `completion` datsets only, uses the provided field instead of `text` column field: + # Using chat template + - path: ... + # Set type to `chat_template` to use this strategy + type: chat_template + # Specify the name of the chat template to use + # The name of the chat template to use for training, following values are supported: + # - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default. + # - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py + # - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml. + # - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field. + chat_template: tokenizer_default + # Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`). + chat_template_jinja: + # The key in the data example that contains the messages. Default is "messages". + field_messages: messages + # The key in the message turn that contains the role. Default is "role". + message_field_role: role + # The key in the message turn that contains the content. Default is "content". + message_field_content: content + # Optional[Dict[str, List]]. Roles mapping for the messages. + roles: + user: ["human", "user"] + assistant: ["gpt", "assistant", "ai"] + system: ["system"] + + ## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on. + + # Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss. + roles_to_train: ["gpt", "assistant"] + # Optional[str]. Which EOS tokens to train on in the conversation. Possible values are: + # - all: train on all EOS tokens + # - turn: train on the EOS token at the end of each trainable turn + # - last: train on the last EOS token in the conversation + train_on_eos: last + # The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`. + message_field_training: training + # The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn. + # The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train). + # See example at `docs/dataset-formats/conversation.qmd` + message_field_training_detail: train_detail + + # If false, the datasets will not be shuffled and will keep their original order in `datasets`. # The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true. shuffle_merged_datasets: true @@ -142,9 +184,16 @@ test_datasets: # use RL training: 'dpo', 'ipo', 'kto' rl: -# Saves the desired chat template to the tokenizer_config.json for easier inferencing -# Currently supports chatml and inst (mistral/mixtral) -chat_template: chatml +# The name of the chat template to use for training, following values are supported: +# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value. +# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py +# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer. +# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field. +# The selected chat template will be saved to the tokenizer_config.json for easier inferencing +# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template. +chat_template: tokenizer_default +# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null. +chat_template_jinja: null # Changes the default system message default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml. # Axolotl attempts to save the dataset as an arrow after packing the data together so diff --git a/docs/dataset-formats/conversation.qmd b/docs/dataset-formats/conversation.qmd index 28d13c987c..c7273c5be5 100644 --- a/docs/dataset-formats/conversation.qmd +++ b/docs/dataset-formats/conversation.qmd @@ -6,6 +6,8 @@ order: 3 ## sharegpt +UPDATE: ShareGPT is being deprecated in the next release. Please see `chat_template` section below. + conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt) ```{.json filename="data.jsonl"} @@ -69,3 +71,138 @@ creates a chat where bot is asked to tell a joke, then explain why the joke is f ```{.json filename="data.jsonl"} {"conversations": [{"title": "...", "text": "...", "explanation": "..."}]} ``` + + +## chat_template + +Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2. + +```{.json filename="data.jsonl"} +{"conversations": [{"role": "...", "content": "..."}]} +``` + +See `config.qmd` for full configs and supported templates. + +### Migrating from sharegpt + +Most configs can be adapted as follows: + +```yaml +# old +chat_template: chatml +datasets: + - path: ... + type: sharegpt + conversation: chatml + +# new (if using tokenizer's chat_template) +datasets: + - path: ... + type: chat_template + + field_messages: conversations + message_field_role: from + message_field_content: value + +# new (if setting a new chat_template like chatml, gemma, etc) +chat_template: chatml +datasets: + - path: ... + type: chat_template + + field_messages: conversations + message_field_role: from + message_field_content: value +``` + +We recommend checking the below examples for other usecases. + +### Examples + +1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message. + +```yaml +datasets: + - path: ... + type: chat_template +``` + +2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages. + +```yaml +chat_template: gemma # this overwrites the tokenizer's chat_template +datasets: + - path: ... + type: chat_template + roles_to_train: ["assistant"] +``` + +3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages. + +```yaml +chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template +datasets: + - path: ... + type: chat_template + roles_to_train: ["assistant"] +``` + +4. Using a custom jinja template on OpenAI messages format, training on all assistant messages. + +```yaml +# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty +chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}" + +datasets: + - path: ... + type: chat_template + roles_to_train: ["assistant"] +``` + +5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation + +For a data sample that looks like: + +```{.json filename="data.jsonl"} +{ + "conversations": [ + {"from": "system", "value": "You are an AI assistant.", "train": false}, + {"from": "human", "value": "Hello", "train": false}, + {"from": "assistant", "value": "Hello", "train": true}, + {"from": "human", "value": "How are you?", "train": true}, + { + "from": "assistant", + "value": "I'm doing very well, thank you!", + "train_detail": [ + {"begin_offset": 0, "end_offset": 8, "train": false}, + {"begin_offset": 9, "end_offset": 18, "train": true}, + {"begin_offset": 19, "end_offset": 30, "train": false}, + ], + }, + { + "from": "human", + "value": "I'm doing very well, thank you!", + "train": true, + }, + {"from": "assistant", "value": "Hi there!", "train": true} + ] +} +``` + +The configuration would look like: + +```yaml +datasets: + - path: ... + type: chat_template + chat_template: tokenizer_default + field_messages: conversations + message_field_role: from + message_field_content: value + roles_to_train: [] + train_on_eos: turn + message_field_training: train + message_field_training_detail: train_detail +``` + +Tip: It is not necessary to use both `message_field_training` and `message_field_training_detail` at a time. diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 77bb551f8c..52765a9b58 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -30,7 +30,7 @@ from axolotl.integrations.base import PluginManager from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.comet_ import setup_comet_env_vars from axolotl.utils.config import ( normalize_cfg_datasets, @@ -272,7 +272,7 @@ def do_inference_gradio( importlib.import_module("axolotl.prompters"), prompter ) elif cfg.chat_template: - chat_template_str = chat_templates(cfg.chat_template) + chat_template_str = get_chat_template(cfg.chat_template) model = model.to(cfg.device, dtype=cfg.torch_dtype) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 319ea7be59..d125f838d3 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -63,7 +63,7 @@ log_prediction_callback_factory, ) from axolotl.utils.callbacks.lisa import lisa_callback_factory -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, @@ -1594,7 +1594,7 @@ def build(self, total_num_steps): training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) if self.cfg.chat_template: - training_arguments_kwargs["chat_template"] = chat_templates( + training_arguments_kwargs["chat_template"] = get_chat_template( self.cfg.chat_template ) diff --git a/src/axolotl/prompt_strategies/bradley_terry/__init__.py b/src/axolotl/prompt_strategies/bradley_terry/__init__.py index 849d84e458..4457c50be5 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/__init__.py +++ b/src/axolotl/prompt_strategies/bradley_terry/__init__.py @@ -6,7 +6,7 @@ from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig -LOG = logging.getLogger("axolotl.prompt_strategies") +LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry") def load(strategy, tokenizer, cfg, ds_cfg): diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py index ccda0a4bde..fa85cdcb26 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py +++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py @@ -2,13 +2,18 @@ Bradley-Terry model with chat template prompt strategy. """ +import logging from typing import Any, Dict, Optional from axolotl.prompt_strategies.chat_template import ( ChatTemplatePrompter, ChatTemplateStrategy, ) -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template_from_config + +# Configure the logger +LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template") +LOG.setLevel(logging.INFO) class BTChatTemplateStrategy(ChatTemplateStrategy): @@ -27,18 +32,24 @@ def tokenize_prompt(self, prompt): # pylint: disable=duplicate-code prompt[self.messages] = [] if prompt["system"]: - prompt[self.messages].append({"from": "system", "value": prompt["system"]}) - prompt[self.messages].append({"from": "user", "value": prompt["input"]}) - prompt[self.messages].append({"from": "assistant", "value": prompt["chosen"]}) + prompt[self.messages].append( + {"role": "system", "content": prompt["system"]} + ) + prompt[self.messages].append({"role": "user", "content": prompt["input"]}) + prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]}) chosen_tokenized = super().tokenize_prompt(prompt) self.messages = "rejected_messages" # pylint: disable=duplicate-code prompt[self.messages] = [] if prompt["system"]: - prompt[self.messages].append({"from": "system", "value": prompt["system"]}) - prompt[self.messages].append({"from": "user", "value": prompt["input"]}) - prompt[self.messages].append({"from": "assistant", "value": prompt["rejected"]}) + prompt[self.messages].append( + {"role": "system", "content": prompt["system"]} + ) + prompt[self.messages].append({"role": "user", "content": prompt["input"]}) + prompt[self.messages].append( + {"role": "assistant", "content": prompt["rejected"]} + ) rejected_tokenized = super().tokenize_prompt(prompt) return { @@ -53,15 +64,18 @@ def tokenize_prompt(self, prompt): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): ds_cfg = ds_cfg or {} + chat_template_string = get_chat_template_from_config( + cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer + ) prompter_params = { "tokenizer": tokenizer, - "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")), - "message_field_role": ds_cfg.get("message_field_role", "from"), - "message_field_content": ds_cfg.get("message_field_content", "value"), - "message_field_training": ds_cfg.get("message_field_training", "training"), + "chat_template": chat_template_string, + "message_field_role": ds_cfg.get("message_field_role", "role"), + "message_field_content": ds_cfg.get("message_field_content", "content"), + "message_field_training": ds_cfg.get("message_field_training", None), "message_field_training_detail": ds_cfg.get( - "message_field_training_detail", "train_detail" + "message_field_training_detail", None ), "roles": ds_cfg.get("roles"), "drop_system_message": ds_cfg.get("drop_system_message", False), @@ -74,8 +88,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): strategy_params = { "train_on_inputs": cfg.train_on_inputs, "sequence_len": cfg.sequence_len, - "roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]), - "train_on_eos": ds_cfg.get("train_on_eos", "turn"), + "roles_to_train": ds_cfg.get("roles_to_train", []), + "train_on_eos": ds_cfg.get("train_on_eos", None), } strategy = BTChatTemplateStrategy( diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index c7852a707f..0946a4b8c7 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -9,7 +9,7 @@ from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompters import IGNORE_TOKEN_ID, Prompter -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template_from_config # Configure the logger LOG = logging.getLogger("axolotl") @@ -405,10 +405,14 @@ def get_images(self, prompt): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None): # pylint: disable=duplicate-code ds_cfg = ds_cfg or {} + chat_template_string = get_chat_template_from_config( + cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer + ) + LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---") prompter_params = { "tokenizer": tokenizer, - "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")), + "chat_template": chat_template_string, "message_field_role": ds_cfg.get("message_field_role", "role"), "message_field_content": ds_cfg.get("message_field_content", "content"), "message_field_training": ds_cfg.get("message_field_training", None), diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py index e0e5eb1294..489b864851 100644 --- a/src/axolotl/prompt_strategies/dpo/chat_template.py +++ b/src/axolotl/prompt_strategies/dpo/chat_template.py @@ -2,15 +2,16 @@ DPO prompt strategies for using tokenizer chat templates. """ -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template def default( cfg, dataset_idx=0, **kwargs ): # pylint: disable=possibly-unused-variable,unused-argument ds_cfg = cfg["datasets"][dataset_idx] - chat_template_str = chat_templates(cfg.chat_template) - + chat_template_choice, chat_template_jinja = extract_chat_template_args( + cfg=cfg, ds_cfg=ds_cfg + ) field_messages = ds_cfg.get("field_messages", "messages") field_chosen = ds_cfg.get("field_chosen", "chosen") field_rejected = ds_cfg.get("field_rejected", "rejected") @@ -30,6 +31,12 @@ def default( role_map[source] = target def transform_fn(sample, tokenizer=None): + chat_template_string = get_chat_template( + user_choice=chat_template_choice, + jinja_template=chat_template_jinja, + tokenizer=tokenizer, + ) + messages = sample[field_messages] messages = [ { @@ -46,28 +53,29 @@ def transform_fn(sample, tokenizer=None): "role": role_map[sample[field_rejected][field_message_role]], "content": sample[field_rejected][field_message_content], } + dummy_user_message = {"role": "user", "content": "[[dummy_message]]"} result = {} result["prompt"] = tokenizer.apply_chat_template( messages, add_generation_prompt=True, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, ) result["chosen"] = tokenizer.apply_chat_template( - [chosen], + [dummy_user_message, chosen], add_generation_prompt=False, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, ) chosen_strip_index = result["chosen"].find(chosen["content"]) result["chosen"] = result["chosen"][chosen_strip_index:].rstrip() result["rejected"] = tokenizer.apply_chat_template( - [rejected], + [dummy_user_message, rejected], add_generation_prompt=False, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, ) rejected_strip_index = result["rejected"].find(rejected["content"]) diff --git a/src/axolotl/prompt_strategies/orpo/chat_template.py b/src/axolotl/prompt_strategies/orpo/chat_template.py index bba6948568..e53a547483 100644 --- a/src/axolotl/prompt_strategies/orpo/chat_template.py +++ b/src/axolotl/prompt_strategies/orpo/chat_template.py @@ -5,7 +5,7 @@ from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy from axolotl.prompters import Prompter -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template_from_config class Message(BaseModel): @@ -28,18 +28,13 @@ def load( """ chatml transforms for datasets with system, input, chosen, rejected """ - - chat_template = chat_templates("chatml") - if ds_cfg and "chat_template" in ds_cfg: - chat_template = ds_cfg["chat_template"] - try: - chat_template = chat_templates(chat_template) - except ValueError: - pass - tokenizer.chat_template = chat_template + chat_template_string = get_chat_template_from_config( + cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer + ) + tokenizer.chat_template = chat_template_string return ORPOTokenizingStrategy( - ORPOPrompter(chat_template, tokenizer), + ORPOPrompter(chat_template_string, tokenizer), tokenizer, cfg.train_on_inputs, cfg.sequence_len, @@ -248,28 +243,30 @@ def build_prompt( def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument dataset_parser = ORPODatasetParsingStrategy() - chat_template_str = chat_templates(cfg.chat_template) - def transform_fn(sample, tokenizer=None): res = {} + chat_template_string = get_chat_template_from_config( + cfg=cfg, tokenizer=tokenizer + ) + res["prompt"] = tokenizer.apply_chat_template( [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages], add_generation_prompt=True, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, ) prompt_str_len = len(res["prompt"]) res["chosen"] = tokenizer.apply_chat_template( [msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages], add_generation_prompt=False, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, )[prompt_str_len:] res["rejected"] = tokenizer.apply_chat_template( [msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages], add_generation_prompt=False, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, )[prompt_str_len:] diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index 4565c35d5d..069d243f52 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -62,7 +62,7 @@ def build_loader( ): def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): LOG.warning( - "sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead.", + "sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead. https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template", ) conversation = ( ds_cfg["conversation"] diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 2443f56f93..dfb3fef21a 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -2,8 +2,19 @@ This module provides functionality for selecting chat templates based on user choices. These templates are used for formatting messages in a conversation. """ +import logging +from typing import TYPE_CHECKING, Any, Dict, Optional -CHAT_TEMPLATES = { +if TYPE_CHECKING: + from transformers import PreTrainedTokenizerBase + +LOG = logging.getLogger("axolotl.utils.chat_templates") + +_JINJA_TEMPALTE_CHOICE = "jinja" +_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default" +_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_" + +_CHAT_TEMPLATES = { "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", "mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1... "mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3: Mistral 7B V3, Small, Large... @@ -21,12 +32,18 @@ } -def chat_templates(user_choice: str): +def get_chat_template( + user_choice: str, + jinja_template: Optional[str] = None, + tokenizer: Optional["PreTrainedTokenizerBase"] = None, +): """ - Finds the correct chat_template for the tokenizer_config. + Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer. Args: user_choice (str): The user's choice of template. + jinja_template (Optional[str], optional): The jinja template string. Defaults to None. + tokenizer (Optional[PreTrainedTokenizerBase], optional): The tokenizer. Defaults to None. Returns: str: The chosen template string. @@ -34,13 +51,71 @@ def chat_templates(user_choice: str): Raises: ValueError: If the user_choice is not found in the templates. """ + if user_choice == _JINJA_TEMPALTE_CHOICE: + if not jinja_template: + raise ValueError( + f"`jinja_template` cannot be None when `chat_template` choice is {_JINJA_TEMPALTE_CHOICE}" + ) + return jinja_template + + if user_choice == _DEFAULT_TEMPLATE_CHOICE: + if not tokenizer: + raise ValueError( + f"`tokenizer` cannot be None when chat_template choice is {_DEFAULT_TEMPLATE_CHOICE}" + ) + if not tokenizer.chat_template: + raise ValueError( + f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. " + f"Please add a chat_template in tokenizer config" + ) + return tokenizer.chat_template + + if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX): + if not tokenizer: + raise ValueError( + f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}" + ) + if tokenizer.chat_template: + return tokenizer.chat_template - if user_choice in CHAT_TEMPLATES: - return CHAT_TEMPLATES[user_choice] + user_choice = user_choice[ + len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) : + ] + LOG.warning( + f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template." + ) + + if user_choice in _CHAT_TEMPLATES: + return _CHAT_TEMPLATES[user_choice] raise ValueError(f"Template '{user_choice}' not found.") +def extract_chat_template_args(cfg, ds_cfg: Optional[Dict[str, Any]] = None): + if ds_cfg and ds_cfg.get("chat_template"): + chat_template_choice = ds_cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE + chat_template_jinja = ds_cfg.get("chat_template_jinja") + else: + chat_template_choice = cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE + chat_template_jinja = cfg.get("chat_template_jinja") + return chat_template_choice, chat_template_jinja + + +def get_chat_template_from_config( + cfg, + ds_cfg: Optional[Dict[str, Any]] = None, + tokenizer: Optional["PreTrainedTokenizerBase"] = None, +) -> str: + chat_template_choice, chat_template_jinja = extract_chat_template_args( + cfg=cfg, ds_cfg=ds_cfg + ) + return get_chat_template( + user_choice=chat_template_choice, + jinja_template=chat_template_jinja, + tokenizer=tokenizer, + ) + + def register_chat_template(template_name: str, chat_template: str): """ Registers chat templates. @@ -50,7 +125,7 @@ def register_chat_template(template_name: str, chat_template: str): chat_template (str): The template string. """ - if template_name in CHAT_TEMPLATES: + if template_name in _CHAT_TEMPLATES: raise ValueError(f"Template '{template_name}' already exists.") - CHAT_TEMPLATES[template_name] = chat_template + _CHAT_TEMPLATES[template_name] = chat_template diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index f732db06fc..afc8c4fc41 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -228,6 +228,7 @@ def normalize_cfg_datasets(cfg): f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template" ) cfg.datasets[idx].chat_template = cfg.chat_template + cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None): diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 16cf312ce1..96e5330005 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -8,9 +8,16 @@ import os from enum import Enum from importlib.metadata import version -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union -from pydantic import BaseModel, Field, conlist, field_validator, model_validator +from pydantic import ( + BaseModel, + Field, + StringConstraints, + conlist, + field_validator, + model_validator, +) from transformers import SchedulerType from transformers.training_args import OptimizerNames @@ -21,6 +28,37 @@ SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} +class RLType(str, Enum): + """RL trainer type configuration subset""" + + dpo = "dpo" # pylint: disable=invalid-name + ipo = "ipo" # pylint: disable=invalid-name + orpo = "orpo" # pylint: disable=invalid-name + kto = "kto" # pylint: disable=invalid-name + simpo = "simpo" # pylint: disable=invalid-name + + +class ChatTemplate(str, Enum): + """Chat templates configuration subset""" + + alpaca = "alpaca" # pylint: disable=invalid-name + chatml = "chatml" # pylint: disable=invalid-name + mistral_v1 = "mistral_v1" # pylint: disable=invalid-name + mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name + mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name + gemma = "gemma" # pylint: disable=invalid-name + cohere = "cohere" # pylint: disable=invalid-name + llama3 = "llama3" # pylint: disable=invalid-name + llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name + phi_3 = "phi_3" # pylint: disable=invalid-name + phi_35 = "phi_35" # pylint: disable=invalid-name + deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name + jamba = "jamba" # pylint: disable=invalid-name + jinja = "jinja" # pylint: disable=invalid-name + qwen_25 = "qwen_25" # pylint: disable=invalid-name + tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name + + class DeprecatedParameters(BaseModel): """configurations that are deprecated""" @@ -105,13 +143,19 @@ class SFTDataset(BaseModel): input_transform: Optional[str] = None shards: Optional[int] = None conversation: Optional[str] = None - chat_template: Optional[str] = None + # Do not make this too strict or it will break the validator to choose different dataset class + chat_template: Optional[ + Union[ + ChatTemplate, + str, + ] + ] = None + chat_template_jinja: Optional[str] = None data_files: Optional[Union[str, List[str]]] = None input_format: Optional[str] = None name: Optional[str] = None ds_type: Optional[str] = None train_on_split: Optional[str] = None - field: Optional[str] = None field_human: Optional[str] = None field_model: Optional[str] = None @@ -122,13 +166,32 @@ class SFTDataset(BaseModel): message_field_training_detail: Optional[str] = None roles_to_train: Optional[List[str]] = None train_on_eos: Optional[str] = None - roles: Optional[Dict[str, List[str]]] = None drop_system_message: Optional[bool] = None - trust_remote_code: Optional[bool] = False revision: Optional[str] = None + @model_validator(mode="before") + @classmethod + def check_chat_template_config(cls, data): + # Set chat_template to tokenizer_default if not set + if data.get("type") == "chat_template" and not data.get("chat_template"): + data["chat_template"] = ChatTemplate.tokenizer_default + + # if chat_template is set to jinja, chat_template_jinja is required + if data.get("chat_template") == ChatTemplate.jinja and not data.get( + "chat_template_jinja" + ): + raise ValueError( + "chat_template_jinja is required when chat_template is set to jinja" + ) + + # If chat_template_jinja is set, set chat_template to jinja + if data.get("chat_template_jinja") and not data.get("chat_template"): + data["chat_template"] = ChatTemplate.jinja + + return data + class UserDefinedDPOType(BaseModel): """User defined typing for DPO""" @@ -174,35 +237,6 @@ class KTODataset(BaseModel): revision: Optional[str] = None -class RLType(str, Enum): - """RL trainer type configuration subset""" - - dpo = "dpo" # pylint: disable=invalid-name - ipo = "ipo" # pylint: disable=invalid-name - orpo = "orpo" # pylint: disable=invalid-name - kto = "kto" # pylint: disable=invalid-name - simpo = "simpo" # pylint: disable=invalid-name - - -class ChatTemplate(str, Enum): - """Chat templates configuration subset""" - - alpaca = "alpaca" # pylint: disable=invalid-name - chatml = "chatml" # pylint: disable=invalid-name - mistral_v1 = "mistral_v1" # pylint: disable=invalid-name - mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name - mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name - gemma = "gemma" # pylint: disable=invalid-name - cohere = "cohere" # pylint: disable=invalid-name - llama3 = "llama3" # pylint: disable=invalid-name - llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name - phi_3 = "phi_3" # pylint: disable=invalid-name - phi_35 = "phi_35" # pylint: disable=invalid-name - deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name - jamba = "jamba" # pylint: disable=invalid-name - qwen_25 = "qwen_25" # pylint: disable=invalid-name - - class LoftQConfig(BaseModel): """LoftQ configuration subset""" @@ -719,7 +753,13 @@ class Config: gpu_memory_limit: Optional[Union[int, str]] = None low_cpu_mem_usage: Optional[bool] = None - chat_template: Optional[ChatTemplate] = None + chat_template: Optional[ + Union[ + ChatTemplate, + Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")], + ] + ] = None + chat_template_jinja: Optional[str] = None default_system_message: Optional[str] = None fix_untrained_tokens: Optional[bool] = None @@ -828,6 +868,23 @@ def check_sample_packing_w_xformers(cls, data): return data + @model_validator(mode="before") + @classmethod + def check_chat_template_config(cls, data): + # if chat_template is set to jinja, chat_template_jinja is required + if data.get("chat_template") == ChatTemplate.jinja and not data.get( + "chat_template_jinja" + ): + raise ValueError( + "chat_template_jinja is required when chat_template is set to jinja" + ) + + # If chat_template_jinja is set, set chat_template to jinja + if data.get("chat_template_jinja") and not data.get("chat_template"): + data["chat_template"] = ChatTemplate.jinja + + return data + @model_validator(mode="before") @classmethod def check_sample_packing_wo_flash(cls, data): diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 97844a5bf3..f3386cccfa 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -53,7 +53,7 @@ ) from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.utils.bench import log_gpu_memory_usage -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import zero_only from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper @@ -296,7 +296,10 @@ def load_tokenizer(cfg): LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") if cfg.chat_template: - chat_template_string = chat_templates(cfg.chat_template) + chat_template_string = get_chat_template_from_config( + cfg=cfg, + tokenizer=tokenizer, + ) if cfg.default_system_message and cfg.chat_template == "chatml": chat_template_string = chat_template_string.replace( "You are a helpful assistant.", cfg.default_system_message diff --git a/tests/prompt_strategies/test_chat_template_utils.py b/tests/prompt_strategies/test_chat_template_utils.py new file mode 100644 index 0000000000..b63c9aa179 --- /dev/null +++ b/tests/prompt_strategies/test_chat_template_utils.py @@ -0,0 +1,125 @@ +""" +Tests for utils in axolotl.utils.chat_templates +""" +import unittest + +import pytest +from transformers import AutoTokenizer + +from axolotl.utils.chat_templates import ( + _CHAT_TEMPLATES, + extract_chat_template_args, + get_chat_template, +) + + +@pytest.fixture(name="llama3_tokenizer") +def fixture_llama3_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B") + + return tokenizer + + +class TestGetChatTemplateUtils: + """ + Tests the get_chat_template function. + """ + + def test_known_chat_template(self): + chat_template_str = get_chat_template("llama3") + assert chat_template_str == _CHAT_TEMPLATES["llama3"] + + def test_invalid_chat_template(self): + with pytest.raises(ValueError) as exc: + get_chat_template("invalid_template") + assert str(exc) == "Template 'invalid_template' not found." + + def test_tokenizer_default_no_tokenizer(self): + with pytest.raises(ValueError): + get_chat_template("tokenizer_default", tokenizer=None) + + def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer): + with pytest.raises(ValueError): + get_chat_template("tokenizer_default", tokenizer=llama3_tokenizer) + + def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer): + llama3_tokenizer.chat_template = "test_template" + chat_template_str = get_chat_template( + "tokenizer_default", tokenizer=llama3_tokenizer + ) + assert chat_template_str == "test_template" + + def test_tokenizer_default_fallback_no_tokenizer(self): + with pytest.raises(ValueError): + get_chat_template("tokenizer_default_fallback_test", tokenizer=None) + + def test_tokenizer_default_fallback_no_chat_template_on_tokenizer( + self, llama3_tokenizer + ): + chat_template_str = get_chat_template( + "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer + ) + assert chat_template_str == get_chat_template("chatml") + + def test_tokenizer_default_fallback_with_chat_template_on_tokenizer( + self, llama3_tokenizer + ): + llama3_tokenizer.chat_template = "test_template" + chat_template_str = get_chat_template( + "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer + ) + assert chat_template_str == "test_template" + + def test_jinja_template_mode(self): + jinja_template = "example_jinja_template" + chat_template_str = get_chat_template("jinja", jinja_template=jinja_template) + assert chat_template_str == jinja_template + + def test_jinja_template_mode_no_jinja_template(self): + with pytest.raises(ValueError): + get_chat_template("jinja", jinja_template=None) + + def test_extract_chat_template_args(self): + # No ds_cfg + chat_template_choice, chat_template_jinja = extract_chat_template_args( + cfg={"chat_template": "chatml"}, + ) + assert chat_template_choice == "chatml" + assert chat_template_jinja is None + + # ds_cfg provided + chat_template_choice, chat_template_jinja = extract_chat_template_args( + cfg={ + "chat_template": "jinja", + "chat_template_jinja": "global_jinja_template", + }, + ds_cfg={"chat_template": "llama3", "chat_template_jinja": None}, + ) + assert chat_template_choice == "llama3" + assert chat_template_jinja is None + + # ds_cfg provided with jinja template + chat_template_choice, chat_template_jinja = extract_chat_template_args( + cfg={"chat_template": "chatml", "chat_template_jinja": None}, + ds_cfg={ + "chat_template": "jinja", + "chat_template_jinja": "ds_jinja_template", + }, + ) + assert chat_template_choice == "jinja" + assert chat_template_jinja == "ds_jinja_template" + + # ds_cfg provided with no chat_template + chat_template_choice, chat_template_jinja = extract_chat_template_args( + cfg={ + "chat_template": "jinja", + "chat_template_jinja": "global_jinja_template", + }, + ds_cfg={"chat_template": None, "chat_template_jinja": "ds_jinja_template"}, + ) + assert chat_template_choice == "jinja" + assert chat_template_jinja == "global_jinja_template" + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index 20533504ce..4ec12b82cb 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -11,7 +11,7 @@ load, ) from axolotl.prompters import IGNORE_TOKEN_ID -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.dict import DictDefault logging.basicConfig(level=logging.DEBUG) @@ -73,7 +73,7 @@ def test_llama3(self, llama3_tokenizer, assistant_dataset): strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_template=chat_templates("llama3"), + chat_template=get_chat_template("llama3"), message_field_role="role", message_field_content="content", roles={ @@ -113,7 +113,7 @@ def test_phi35(self, phi35_tokenizer, assistant_dataset): strategy = ChatTemplateStrategy( ChatTemplatePrompter( phi35_tokenizer, - chat_template=chat_templates("phi_35"), + chat_template=get_chat_template("phi_35"), message_field_role="role", message_field_content="content", roles={ @@ -171,7 +171,7 @@ def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset): strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_template=chat_templates("llama3"), + chat_template=get_chat_template("llama3"), message_field_role="role", message_field_content="content", message_field_training="training", @@ -230,7 +230,7 @@ def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset): # pylint: disable=duplicate-code strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -283,7 +283,7 @@ def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset): # pylint: disable=duplicate-code strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -336,7 +336,7 @@ def test_llama3_system_human(self, llama3_tokenizer, basic_dataset): # pylint: disable=duplicate-code strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py index 50429e3a26..be8e3ccdf9 100644 --- a/tests/prompt_strategies/test_chat_templates_advanced.py +++ b/tests/prompt_strategies/test_chat_templates_advanced.py @@ -12,7 +12,7 @@ ChatTemplateStrategy, ) from axolotl.prompters import IGNORE_TOKEN_ID -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template logging.basicConfig(level=logging.DEBUG) LOG = logging.getLogger("axolotl") @@ -35,7 +35,7 @@ def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_inputs=True") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=True, @@ -80,7 +80,7 @@ def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_inputs=False") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -123,7 +123,7 @@ def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset): LOG.info("Testing roles_to_train with assistant only") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -151,7 +151,7 @@ def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset): LOG.info("Testing roles_to_train with all roles") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=True, @@ -184,7 +184,7 @@ def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with empty roles_to_train") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -205,7 +205,7 @@ def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='all'") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -232,7 +232,7 @@ def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='turn'") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -282,7 +282,7 @@ def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='last'") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -315,7 +315,7 @@ def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='none'") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -343,7 +343,7 @@ def test_drop_system_message(self, llama3_tokenizer, basic_dataset): strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_template=chat_templates("llama3"), + chat_template=get_chat_template("llama3"), drop_system_message=True, ), tokenizer=llama3_tokenizer, @@ -371,7 +371,7 @@ def test_custom_roles(self, llama3_tokenizer): strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_template=chat_templates("llama3"), + chat_template=get_chat_template("llama3"), roles=custom_roles, ), tokenizer=llama3_tokenizer, @@ -424,7 +424,7 @@ def test_message_field_training(self, llama3_tokenizer): strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_template=chat_templates("llama3"), + chat_template=get_chat_template("llama3"), message_field_training="train", message_field_training_detail="train_detail", ), diff --git a/tests/prompt_strategies/test_dpo_chat_templates.py b/tests/prompt_strategies/test_dpo_chat_templates.py index cca48b1cf3..740edc22f2 100644 --- a/tests/prompt_strategies/test_dpo_chat_templates.py +++ b/tests/prompt_strategies/test_dpo_chat_templates.py @@ -86,6 +86,20 @@ def fixture_llama3_tokenizer(): return tokenizer +@pytest.fixture(name="phi3_tokenizer") +def fixture_phi3_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct") + + return tokenizer + + +@pytest.fixture(name="gemma_tokenizer") +def fixture_gemma_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a") + + return tokenizer + + class TestAssistantDPOChatTemplateLlama3: """ Test class for assistant style datasets with llama-3 prompts using the chat_template strategy. @@ -99,7 +113,7 @@ def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset): "chat_template": "llama3", "datasets": [ { - "chat_template": "llama3", + "type": "chat_template", } ], } @@ -124,7 +138,7 @@ def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset): "chat_template": "llama3", "datasets": [ { - "chat_template": "llama3", + "type": "chat_template", "field_messages": "conversation", "field_chosen": "better", "field_rejected": "worse", @@ -152,5 +166,65 @@ def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset): assert result["rejected"] == "party on<|eot_id|>" +class TestAssistantDPOChatTemplatePhi3: + """ + Test class for assistant style datasets with phi-3 prompts using the tokenizer's chat_template strategy. + """ + + def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset): + # pylint: disable=duplicate-code + transform_fn = default( + DictDefault( + { + "chat_template": "tokenizer_default", + "datasets": [ + { + "type": "chat_template", + } + ], + } + ) + ) + result = transform_fn(assistant_dataset[0], tokenizer=phi3_tokenizer) + assert result["prompt"] == ( + "<|user|>\nhello<|end|>\n" + + "<|assistant|>\nhello<|end|>\n" + + "<|user|>\ngoodbye<|end|>\n" + + "<|assistant|>\n" + ) + assert result["chosen"] == "goodbye<|end|>" + assert result["rejected"] == "party on<|end|>" + + +class TestAssistantDPOChatTemplateGemma: + """ + Test class for assistant style datasets with gemma prompts using the tokenizer's chat_template strategy. + """ + + def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset): + # pylint: disable=duplicate-code + transform_fn = default( + DictDefault( + { + "chat_template": "tokenizer_default", + "datasets": [ + { + "type": "chat_template", + } + ], + } + ) + ) + result = transform_fn(assistant_dataset[0], tokenizer=gemma_tokenizer) + assert result["prompt"] == ( + "user\nhello\n" + + "model\nhello\n" + + "user\ngoodbye\n" + + "model\n" + ) + assert result["chosen"] == "goodbye" + assert result["rejected"] == "party on" + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_validation_dataset.py b/tests/test_validation_dataset.py new file mode 100644 index 0000000000..389424217b --- /dev/null +++ b/tests/test_validation_dataset.py @@ -0,0 +1,238 @@ +"""Module for testing the validation module for the dataset config""" + +import warnings +from typing import Optional + +import pytest + +from axolotl.utils.config import validate_config +from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate +from axolotl.utils.dict import DictDefault + +warnings.filterwarnings("error") + + +@pytest.fixture(name="minimal_cfg") +def fixture_cfg(): + return DictDefault( + { + "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", + "learning_rate": 0.000001, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + } + ) + + +# pylint: disable=too-many-public-methods (duplicate-code) +class BaseValidation: + """ + Base validation module to setup the log capture + """ + + _caplog: Optional[pytest.LogCaptureFixture] = None + + @pytest.fixture(autouse=True) + def inject_fixtures(self, caplog): + self._caplog = caplog + + +class TestValidationCheckDatasetConfig(BaseValidation): + """ + Test the validation for the dataset config to ensure no correct parameters are dropped + """ + + def test_dataset_config_no_drop_param(self, minimal_cfg): + cfg = DictDefault( + minimal_cfg + | { + "datasets": [ + { + "path": "LDJnr/Puffin", + "type": "sharegpt", + "conversation": "chatml", + "shards": 10, + } + ] + } + ) + + checked_cfg = validate_config(cfg) + + def _check_config(): + assert checked_cfg.datasets[0].path == cfg.datasets[0].path + assert checked_cfg.datasets[0].type == cfg.datasets[0].type + assert checked_cfg.datasets[0].conversation == cfg.datasets[0].conversation + assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards + + _check_config() + + checked_cfg = validate_config( + cfg, + capabilities={ + "bf16": "false", + "n_gpu": 1, + "compute_capability": "8.0", + }, + ) + + _check_config() + + def test_dataset_default_chat_template_no_drop_param(self, minimal_cfg): + cfg = DictDefault( + minimal_cfg + | { + "datasets": [ + { + "path": "LDJnr/Puffin", + "type": "chat_template", + "field_messages": "conversations", + "shards": 10, + "message_field_role": "from", + "message_field_content": "value", + } + ], + } + ) + + checked_cfg = validate_config(cfg) + + def _check_config(): + assert checked_cfg.datasets[0].path == cfg.datasets[0].path + assert checked_cfg.datasets[0].type == cfg.datasets[0].type + assert checked_cfg.chat_template is None + assert ( + checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default + ) + assert ( + checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages + ) + assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards + assert ( + checked_cfg.datasets[0].message_field_role + == cfg.datasets[0].message_field_role + ) + assert ( + checked_cfg.datasets[0].message_field_content + == cfg.datasets[0].message_field_content + ) + + _check_config() + + checked_cfg = validate_config( + cfg, + capabilities={ + "bf16": "false", + "n_gpu": 1, + "compute_capability": "8.0", + }, + ) + + _check_config() + + def test_dataset_partial_default_chat_template_no_drop_param(self, minimal_cfg): + cfg = DictDefault( + minimal_cfg + | { + "chat_template": "chatml", + "datasets": [ + { + "path": "LDJnr/Puffin", + "type": "chat_template", + "field_messages": "conversations", + "shards": 10, + "message_field_role": "from", + "message_field_content": "value", + } + ], + } + ) + + checked_cfg = validate_config(cfg) + + def _check_config(): + assert checked_cfg.datasets[0].path == cfg.datasets[0].path + assert checked_cfg.datasets[0].type == cfg.datasets[0].type + assert checked_cfg.chat_template == ChatTemplate.chatml + assert ( + checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default + ) + assert ( + checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages + ) + assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards + assert ( + checked_cfg.datasets[0].message_field_role + == cfg.datasets[0].message_field_role + ) + assert ( + checked_cfg.datasets[0].message_field_content + == cfg.datasets[0].message_field_content + ) + + _check_config() + + checked_cfg = validate_config( + cfg, + capabilities={ + "bf16": "false", + "n_gpu": 1, + "compute_capability": "8.0", + }, + ) + + _check_config() + + def test_dataset_chatml_chat_template_no_drop_param(self, minimal_cfg): + cfg = DictDefault( + minimal_cfg + | { + "chat_template": "chatml", + "datasets": [ + { + "path": "LDJnr/Puffin", + "type": "chat_template", + "chat_template": "gemma", + "field_messages": "conversations", + "shards": 10, + "message_field_role": "from", + "message_field_content": "value", + } + ], + } + ) + + checked_cfg = validate_config(cfg) + + def _check_config(): + assert checked_cfg.datasets[0].path == cfg.datasets[0].path + assert checked_cfg.datasets[0].type == cfg.datasets[0].type + assert checked_cfg.chat_template == cfg.chat_template + assert ( + checked_cfg.datasets[0].chat_template == cfg.datasets[0].chat_template + ) + assert ( + checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages + ) + assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards + assert ( + checked_cfg.datasets[0].message_field_role + == cfg.datasets[0].message_field_role + ) + assert ( + checked_cfg.datasets[0].message_field_content + == cfg.datasets[0].message_field_content + ) + + _check_config() + + checked_cfg = validate_config( + cfg, + capabilities={ + "bf16": "false", + "n_gpu": 1, + "compute_capability": "8.0", + }, + ) + + _check_config() From 107b67b852badb4d269e8100e76b544ac220d4aa Mon Sep 17 00:00:00 2001 From: Oliver Kunc <36070570+OliverKunc@users.noreply.github.com> Date: Tue, 29 Oct 2024 15:13:50 +0100 Subject: [PATCH 88/89] Hardware requirements (#1997) [skip ci] * Hardware requirements https://github.com/axolotl-ai-cloud/axolotl/issues/1992 * Update README.md --------- Co-authored-by: Wing Lian --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 21b954a56c..c12aa3bba0 100644 --- a/README.md +++ b/README.md @@ -121,7 +121,7 @@ Features: Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task. -**Requirements**: Python >=3.10 and Pytorch >=2.1.1. +**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1. ```bash git clone https://github.com/axolotl-ai-cloud/axolotl From 8c3a727f9d60ffd3af385f90bcc3fa3a56398fe1 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 29 Oct 2024 21:26:03 +0700 Subject: [PATCH 89/89] feat: update yml chat_template to specify dataset field (#2001) [skip ci] * feat: update yml chat_template to specify dataset field * feat: replace sharegpt references with chat_template --- .../{dev_sharegpt.yml => dev_chat_template.yml} | 4 ++-- docs/debugging.qmd | 14 +++++++------- examples/deepseek-v2/qlora-fsdp-2_5.yaml | 5 ++++- examples/gemma2/qlora.yml | 5 ++++- examples/jamba/qlora_fsdp_large.yaml | 6 +++++- examples/llama-3/fft-8b-liger-fsdp.yaml | 4 ++++ examples/phi/lora-3.5.yaml | 1 - 7 files changed, 26 insertions(+), 13 deletions(-) rename devtools/{dev_sharegpt.yml => dev_chat_template.yml} (92%) diff --git a/devtools/dev_sharegpt.yml b/devtools/dev_chat_template.yml similarity index 92% rename from devtools/dev_sharegpt.yml rename to devtools/dev_chat_template.yml index 9c65b49dcd..9697da4b33 100644 --- a/devtools/dev_sharegpt.yml +++ b/devtools/dev_chat_template.yml @@ -7,8 +7,8 @@ load_in_8bit: true load_in_4bit: false datasets: - - path: philschmid/guanaco-sharegpt-style - type: sharegpt + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template shards: 10 val_set_size: 0 output_dir: temp_debug/axolotl_outputs/model diff --git a/docs/debugging.qmd b/docs/debugging.qmd index 1d0779b073..029549d85b 100644 --- a/docs/debugging.qmd +++ b/docs/debugging.qmd @@ -51,12 +51,12 @@ While debugging it's helpful to simplify your test scenario as much as possible. ### Background -The below example shows how to configure VSCode to debug data preprocessing of the `sharegpt` format. This is the format used when you have the following in your axolotl config: +The below example shows how to configure VSCode to debug data preprocessing of the `chat_template` format. This is the format used when you have the following in your axolotl config: ```yaml datasets: - - path: # example on HF Hub: philschmid/guanaco-sharegpt-style - type: sharegpt + - path: # example on HF Hub: fozziethebeat/alpaca_messages_2k_test + type: chat_template ``` >[!Important] @@ -83,7 +83,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely. The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs. -For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_sharegpt.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch. +For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch. ```jsonc // .vscode/launch.json @@ -91,12 +91,12 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler "version": "0.2.0", "configurations": [ { - "name": "Debug axolotl prompt - sharegpt", + "name": "Debug axolotl prompt - chat_template", "type": "python", "module": "accelerate.commands.launch", "request": "launch", "args": [ - "-m", "axolotl.cli.train", "dev_sharegpt.yml", + "-m", "axolotl.cli.train", "dev_chat_template.yml", // The flags below simplify debugging by overriding the axolotl config // with the debugging tips above. Modify as needed. "--dataset_processes=1", // limits data preprocessing to one process @@ -240,6 +240,6 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
-[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/sharegpt.yml`, but this is the same thing. +[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing. [^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html). diff --git a/examples/deepseek-v2/qlora-fsdp-2_5.yaml b/examples/deepseek-v2/qlora-fsdp-2_5.yaml index 6e82062d66..0320e02138 100644 --- a/examples/deepseek-v2/qlora-fsdp-2_5.yaml +++ b/examples/deepseek-v2/qlora-fsdp-2_5.yaml @@ -16,7 +16,10 @@ chat_template: deepseek_v2 datasets: - path: mlabonne/FineTome-100k type: chat_template - split: train + split: train[:20%] + field_messages: conversations + message_field_role: from + message_field_content: value dataset_prepared_path: last_run_prepared val_set_size: 0.0 diff --git a/examples/gemma2/qlora.yml b/examples/gemma2/qlora.yml index b6dd653750..00e6d84e0d 100644 --- a/examples/gemma2/qlora.yml +++ b/examples/gemma2/qlora.yml @@ -11,8 +11,11 @@ chat_template: gemma datasets: - path: cgato/SlimOrcaDedupCleaned type: chat_template - chat_template: gemma drop_system_message: true + field_messages: conversations + message_field_role: from + message_field_content: value + val_set_size: 0.0 output_dir: ./outputs/out diff --git a/examples/jamba/qlora_fsdp_large.yaml b/examples/jamba/qlora_fsdp_large.yaml index 28316efd57..84cf906422 100644 --- a/examples/jamba/qlora_fsdp_large.yaml +++ b/examples/jamba/qlora_fsdp_large.yaml @@ -4,11 +4,15 @@ tokenizer_type: AutoTokenizer load_in_4bit: true strict: false use_tensorboard: true +chat_template: jamba datasets: - path: cgato/SlimOrcaDedupCleaned type: chat_template - chat_template: jamba drop_system_message: true + field_messages: conversations + message_field_role: from + message_field_content: value + dataset_prepared_path: last_run_prepared val_set_size: 0.0 output_dir: jamba-large-fsdp-qlora-ft diff --git a/examples/llama-3/fft-8b-liger-fsdp.yaml b/examples/llama-3/fft-8b-liger-fsdp.yaml index e84d221f85..99ba63fcc6 100644 --- a/examples/llama-3/fft-8b-liger-fsdp.yaml +++ b/examples/llama-3/fft-8b-liger-fsdp.yaml @@ -14,6 +14,10 @@ datasets: - path: mlabonne/FineTome-100k type: chat_template split: train[:20%] + field_messages: conversations + message_field_role: from + message_field_content: value + dataset_prepared_path: last_run_prepared val_set_size: 0.02 output_dir: ./outputs/out diff --git a/examples/phi/lora-3.5.yaml b/examples/phi/lora-3.5.yaml index 59d667b8db..246701148c 100644 --- a/examples/phi/lora-3.5.yaml +++ b/examples/phi/lora-3.5.yaml @@ -10,7 +10,6 @@ chat_template: phi_3 datasets: - path: fozziethebeat/alpaca_messages_2k_test type: chat_template - chat_template: phi_3 field_messages: messages message_field_role: role message_field_content: content