diff --git a/.gitignore b/.gitignore index e6dfee67db..d15f2000f3 100644 --- a/.gitignore +++ b/.gitignore @@ -176,3 +176,9 @@ qlora-out/* mlruns/* /.quarto/ +prepared-datasets/ +submit.sh +*.out* + +typings/ +out/ diff --git a/examples/phi/phi3-ft.yml b/examples/phi/phi3-ft.yml new file mode 100644 index 0000000000..18db9b8b78 --- /dev/null +++ b/examples/phi/phi3-ft.yml @@ -0,0 +1,64 @@ +base_model: microsoft/Phi-3-mini-4k-instruct +trust_remote_code: true +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer +chat_template: phi_3 + +load_in_8bit: false +load_in_4bit: false +strict: false + +datasets: + - path: garage-bAInd/Open-Platypus + type: alpaca:phi + +dataset_prepared_path: +val_set_size: 0.01 +output_dir: ./out + +sequence_len: 4096 +sample_packing: true +pad_to_sequence_len: true + +adapter: lora +lora_model_dir: +lora_r: 64 +lora_alpha: 32 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +gradient_accumulation_steps: 1 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_torch +adam_beta2: 0.95 +adam_epsilon: 0.00001 +max_grad_norm: 1.0 +lr_scheduler: cosine +learning_rate: 5.0e-6 + +train_on_inputs: false +group_by_length: false +bf16: auto + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: True +early_stopping_patience: 3 +logging_steps: 1 +flash_attention: true + +eval_steps: 1000 +save_steps: 5000 +eval_table_size: 2 +eval_batch_size: 2 +eval_sample_packing: false +eval_max_new_tokens: 32 +eval_causal_lm_metrics: ["perplexity"] +do_causal_lm_eval: true + +warmup_ratio: 0.2 +debug: true +weight_decay: 0.1 +resize_token_embeddings_to_32x: true diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 60ea5c99f9..102e9e53be 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -20,6 +20,7 @@ class PromptStyle(Enum): INSTRUCT = "instruct" CHAT = "chat" CHATML = "chatml" + PHI = "phi" class Prompter: @@ -38,9 +39,9 @@ class AlpacaPrompter(Prompter): system_format: str = "{system}" turn_format: str turn_no_input_format: str - prompt_style: Optional[PromptStyle] = None + prompt_style: Optional[str] = None - def __init__(self, prompt_style=PromptStyle.INSTRUCT.value): + def __init__(self, prompt_style: Optional[str] = PromptStyle.INSTRUCT.value): self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value self.match_prompt_style() @@ -52,16 +53,20 @@ def match_prompt_style(self): "### Instruction:\n{instruction}\n\n### Response:\n" ) self.system_format = "{system}\n\n" - if self.prompt_style == PromptStyle.CHAT.value: + elif self.prompt_style == PromptStyle.CHAT.value: self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" self.system_format = "SYSTEM: {system}\n" - if self.prompt_style == PromptStyle.CHATML.value: + elif self.prompt_style == PromptStyle.CHATML.value: self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n" self.turn_no_input_format = ( "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n" ) self.system_format = "<|im_start|>system\n{system}<|im_end|>\n" + elif self.prompt_style == PromptStyle.PHI.value: + self.turn_format = "<|user|>\n{instruction}<|end|>{input}<|assistant|>" + self.turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>" + self.system_format = "<|system|>{system}\n" def _build_result(self, instruction, input_text, output): # returns the full prompt from instruction and optional input @@ -381,12 +386,14 @@ def __init__( conversation: Optional[Union[str, Conversation]] = None, role_key_human: Optional[str] = None, role_key_model: Optional[str] = None, + role_key_tool: Optional[str] = None, roles: Optional[dict] = None, ): super().__init__( conversation=conversation, role_key_human=role_key_human, role_key_model=role_key_model, + role_key_tool=role_key_tool, roles=roles, ) diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index c21ef0ad7a..73715b06ab 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -5,6 +5,7 @@ import logging import math import os +import traceback from shutil import copyfile from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING, Any, Dict, List @@ -30,6 +31,7 @@ from axolotl.utils import is_mlflow_available from axolotl.utils.bench import log_gpu_memory_usage +from axolotl.utils.callbacks.perplexity import Perplexity from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig from axolotl.utils.distributed import ( barrier, @@ -374,10 +376,14 @@ def __init__(self, cfg): def __maybe_load_metrics(self): metrics = {} for metric in self.cfg.eval_causal_lm_metrics: - try: - metrics[metric] = evaluate.load(metric) - except Exception as exc: # pylint: disable=broad-exception-caught - LOG.warning(f"{metric}: {exc.args}") + if metric == "perplexity": + max_seq_len = self.cfg.eval_max_new_tokens + metrics[metric] = Perplexity(trainer.model, tokenizer, max_seq_len) + else: + try: + metrics[metric] = evaluate.load(metric) + except Exception as exc: # pylint: disable=broad-exception-caught + LOG.warning(f"{metric}: {exc.args}") return metrics def on_evaluate( @@ -421,13 +427,20 @@ def compute(metric: evaluate.Metric, **kwargs): # safely compute a metric and return the score if the format is correct metric_score = None try: - metric_score = metric.compute(**kwargs) + # Only pass the kwargs that are in the metric's feature list + metric_kwargs = { + k: kwargs[k] + for k in metric._feature_names() # pylint: disable=protected-access + if k in kwargs + } + metric_score = metric.compute(**metric_kwargs) return ( metric_score["score"] if "score" in metric_score else metric_score["mean_score"] ) except Exception: # pylint: disable=broad-exception-caught + traceback.print_exc() LOG.debug( f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}" ) @@ -443,11 +456,12 @@ def evaluate_preds(sources, predictions, references): predictions=predictions, sources=sources, ) - score = score or compute( - metric, - references=[[r] for r in references], - predictions=predictions, - ) + if score is None: + score = compute( + metric, + references=[[r] for r in references], + predictions=predictions, + ) scores[metric_name] = score return scores diff --git a/src/axolotl/utils/callbacks/perplexity.py b/src/axolotl/utils/callbacks/perplexity.py new file mode 100644 index 0000000000..2e64176812 --- /dev/null +++ b/src/axolotl/utils/callbacks/perplexity.py @@ -0,0 +1,76 @@ +"""callback to calculate perplexity as an evaluation metric.""" +from typing import Dict, List, Optional + +import torch +from torch import Tensor +from tqdm import tqdm +from transformers.modeling_outputs import CausalLMOutput +from transformers.modeling_utils import PreTrainedModel +from transformers.tokenization_utils import PreTrainedTokenizer + + +class Perplexity: + """ + Calculate perplexity as defined in https://huggingface.co/docs/transformers/en/perplexity. + This is a custom variant that doesn't re-tokenize the input or re-load the model. + """ + + def __init__( + self, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + max_seq_len: int, + stride: int = 512, + ) -> None: + self.max_seq_len = max_seq_len + self.stride = stride + self.model = model + self.tokenizer = tokenizer + self.device = model.device + self.name = "perplexity" + + def _feature_names(self) -> List[str]: + return ["references"] + + def compute( + self, + references: Optional[List[str]] = None, + ) -> Dict[str, float]: + """ + Compute perplexity in a fixed length sliding window across the sequence. + """ + assert references is not None, "Missing parameter: references" + + references_tokenized = self.tokenizer( + references, return_tensors="pt", padding=True, truncation=True + ) + input_ids: Tensor = references_tokenized["input_ids"] # type: ignore + input_ids = input_ids.to(self.device) + + sequence_length = input_ids.size(1) + + losses = [] + prev_end_loc = 0 + for begin_loc in tqdm(range(0, sequence_length, self.stride)): + end_loc = min(begin_loc + self.max_seq_len, sequence_length) + trg_len = end_loc - prev_end_loc + input_ids_slice = input_ids[:, begin_loc:end_loc] + labels_slice = input_ids_slice.clone() + labels_slice[:, :-trg_len] = -100 + + with torch.no_grad(): + outputs: CausalLMOutput = self.model( + input_ids=input_ids_slice, labels=labels_slice + ) + + losses.append(outputs.loss) + + prev_end_loc = end_loc + if end_loc == sequence_length: + break + + perplexity = torch.exp(torch.stack(losses).mean()).item() + + return { + "score": perplexity, + } diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 1fe888aa80..725934cf56 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -25,6 +25,7 @@ def chat_templates(user_choice: str): "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", + "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", } if user_choice in templates: diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index ad551c74af..ed165e89ca 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -10,6 +10,7 @@ from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.config.models.input.v0_4_1 import ( + SUPPORTED_METRICS, AxolotlConfigWCapabilities, AxolotlInputConfig, ) @@ -586,13 +587,12 @@ def legacy_validate_config(cfg): ) if cfg.eval_causal_lm_metrics: - supported_metrics = ["sacrebleu", "comet", "ter", "chrf"] if not isinstance(cfg.eval_causal_lm_metrics, list): raise ValueError("eval_causal_lm_metrics must be a list") # only ["sacrebleu", "comet", "ter", "chrf"] supported - if set(cfg.eval_causal_lm_metrics) - set(supported_metrics): + if set(cfg.eval_causal_lm_metrics) - SUPPORTED_METRICS: raise ValueError( - f"eval_causal_lm_metrics must be one of {supported_metrics}" + f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}" ) # TODO diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index f363ebfdce..1af3608535 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -17,6 +17,8 @@ LOG = logging.getLogger("axolotl.utils.config.models.input") +SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} + class DeprecatedParameters(BaseModel): """configurations that are deprecated""" @@ -176,6 +178,7 @@ class ChatTemplate(str, Enum): gemma = "gemma" # pylint: disable=invalid-name cohere = "cohere" # pylint: disable=invalid-name llama3 = "llama3" # pylint: disable=invalid-name + phi_3 = "phi_3" # pylint: disable=invalid-name class LoftQConfig(BaseModel): @@ -1072,13 +1075,12 @@ def check_causal_lm_evals(cls, data): ) if data.get("eval_causal_lm_metrics"): - supported_metrics = ["sacrebleu", "comet", "ter", "chrf"] if not isinstance(data.get("eval_causal_lm_metrics"), list): raise ValueError("eval_causal_lm_metrics must be a list") # only ["sacrebleu", "comet", "ter", "chrf"] supported - if set(data.get("eval_causal_lm_metrics")) - set(supported_metrics): + if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS: raise ValueError( - f"eval_causal_lm_metrics must be one of {supported_metrics}" + f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}" ) return data diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index b3e754bc03..bbea1987f1 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -474,12 +474,16 @@ def load_prepare_datasets( index=cfg.dataset_shard_idx, ) - if split == "train" and cfg.val_set_size: + val_set_size = ( + int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size) + ) + + if split == "train" and val_set_size: # ensure we end up with the same fingerprint by doing rank0 first and being able to cache to_hash_train = ( dataset._fingerprint # pylint: disable=protected-access + "|" - + str(cfg.val_set_size) + + str(val_set_size) + "|" + "train" + "|" @@ -488,7 +492,7 @@ def load_prepare_datasets( to_hash_test = ( dataset._fingerprint # pylint: disable=protected-access + "|" - + str(cfg.val_set_size) + + str(val_set_size) + "|" + "test" + "|" @@ -498,9 +502,7 @@ def load_prepare_datasets( test_fingerprint = md5(to_hash_test) dataset = dataset.train_test_split( - test_size=int(cfg.val_set_size) - if cfg.val_set_size == int(cfg.val_set_size) - else cfg.val_set_size, + test_size=val_set_size, shuffle=False, seed=cfg.seed or 42, train_new_fingerprint=train_fingerprint, @@ -535,6 +537,10 @@ def get_dataset_wrapper( "keep_in_memory": cfg.dataset_keep_in_memory is True, } + LOG.info( + f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}" + ) + if ( isinstance(dataset, Dataset) and "input_ids" in dataset.features diff --git a/tests/test_perplexity.py b/tests/test_perplexity.py new file mode 100644 index 0000000000..e66e95d0cd --- /dev/null +++ b/tests/test_perplexity.py @@ -0,0 +1,41 @@ +"""unit tests for perplexity eval callback""" +# pylint: disable=redefined-outer-name + +from pytest import fixture +from transformers.models.auto.modeling_auto import AutoModelForCausalLM +from transformers.models.auto.tokenization_auto import AutoTokenizer + +from axolotl.utils.callbacks.perplexity import Perplexity + +MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + + +@fixture() +def metric(tokenizer): + model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True) + + return Perplexity(model, tokenizer, 512) + + +@fixture() +def tokenizer(): + return AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) + + +def test_perplexity_longer_than_stride(metric): + # taken from https://huggingface.co/datasets/roneneldan/TinyStories + sample_text = """ +Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong. One day, Beep was driving in the park when he saw a big tree. The tree had many leaves that were falling. Beep liked how the leaves fall and wanted to play with them. Beep drove under the tree and watched the leaves fall on him. He laughed and beeped his horn. Beep played with the falling leaves all day. When it was time to go home, Beep knew he needed more fuel. He went to the fuel place and got more healthy fuel. Now, Beep was ready to go fast and play again the next day. And Beep lived happily ever after. +One day, a little fish named Fin was swimming near the shore. He saw a big crab and wanted to be friends. "Hi, I am Fin. Do you want to play?" asked the little fish. The crab looked at Fin and said, "No, I don't want to play. I am cold and I don't feel fine." Fin felt sad but wanted to help the crab feel better. He swam away and thought of a plan. He remembered that the sun could make things warm. So, Fin swam to the top of the water and called to the sun, "Please, sun, help my new friend feel fine and not freeze!" The sun heard Fin's call and shone its warm light on the shore. The crab started to feel better and not so cold. He saw Fin and said, "Thank you, little fish, for making me feel fine. I don't feel like I will freeze now. Let's play together!" And so, Fin and the crab played and became good friends. +""" + result = metric.compute([sample_text]) + ppl = result["score"] + assert round(ppl, 2) == 5.37 + + +def test_perplexity_short(metric): + # taken from https://huggingface.co/datasets/roneneldan/TinyStories + sample_text = "Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun." + result = metric.compute([sample_text]) + ppl = result["score"] + assert round(ppl, 2) == 10.02