
Phi-3 conversation format, example training script and perplexity metric #1582

Merged · 7 commits · Jun 4, 2024
6 changes: 6 additions & 0 deletions .gitignore
@@ -176,3 +176,9 @@ qlora-out/*
mlruns/*

/.quarto/
prepared-datasets/
submit.sh
*.out*

typings/
out/
64 changes: 64 additions & 0 deletions examples/phi/phi3-ft.yml
@@ -0,0 +1,64 @@
base_model: microsoft/Phi-3-mini-4k-instruct
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
chat_template: phi_3

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
- path: garage-bAInd/Open-Platypus
type: alpaca:phi

dataset_prepared_path:
val_set_size: 0.01
output_dir: ./out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

adapter: lora
lora_model_dir:
lora_r: 64
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0
lr_scheduler: cosine
learning_rate: 5.0e-6

train_on_inputs: false
group_by_length: false
bf16: auto

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: True
early_stopping_patience: 3
logging_steps: 1
flash_attention: true

eval_steps: 1000
save_steps: 5000
eval_table_size: 2
eval_batch_size: 2
eval_sample_packing: false
eval_max_new_tokens: 32
eval_causal_lm_metrics: ["perplexity"]
do_causal_lm_eval: true

warmup_ratio: 0.2
debug: true
weight_decay: 0.1
resize_token_embeddings_to_32x: true
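
To sanity-check the example config before launching a run, the YAML can be loaded and the fields this PR wires together inspected directly (a minimal sketch; assumes PyYAML and a repo checkout with the file at `examples/phi/phi3-ft.yml`):

```python
import yaml

# Load the example config and spot-check the fields added in this PR.
with open("examples/phi/phi3-ft.yml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

assert cfg["chat_template"] == "phi_3"
assert "perplexity" in cfg["eval_causal_lm_metrics"]
# Note: eval_max_new_tokens also serves as the perplexity window size
# (max_seq_len) in the eval callback change further down.
print(cfg["base_model"], cfg["sequence_len"], cfg["eval_max_new_tokens"])
```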
15 changes: 11 additions & 4 deletions src/axolotl/prompters.py
@@ -20,6 +20,7 @@ class PromptStyle(Enum):
INSTRUCT = "instruct"
CHAT = "chat"
CHATML = "chatml"
PHI = "phi"


class Prompter:
@@ -38,9 +39,9 @@ class AlpacaPrompter(Prompter):
system_format: str = "{system}"
turn_format: str
turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None
prompt_style: Optional[str] = None

def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
def __init__(self, prompt_style: Optional[str] = PromptStyle.INSTRUCT.value):
self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
self.match_prompt_style()

@@ -52,16 +53,20 @@ def match_prompt_style(self):
"### Instruction:\n{instruction}\n\n### Response:\n"
)
self.system_format = "{system}\n\n"
if self.prompt_style == PromptStyle.CHAT.value:
elif self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
self.system_format = "SYSTEM: {system}\n"
if self.prompt_style == PromptStyle.CHATML.value:
elif self.prompt_style == PromptStyle.CHATML.value:
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
self.turn_no_input_format = (
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
)
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
elif self.prompt_style == PromptStyle.PHI.value:
self.turn_format = "<|user|>\n{instruction}<|end|>{input}<|assistant|>"
self.turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>"
self.system_format = "<|system|>{system}\n"

def _build_result(self, instruction, input_text, output):
# returns the full prompt from instruction and optional input
@@ -381,12 +386,14 @@ def __init__(
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
role_key_tool: Optional[str] = None,
roles: Optional[dict] = None,
):
super().__init__(
conversation=conversation,
role_key_human=role_key_human,
role_key_model=role_key_model,
role_key_tool=role_key_tool,
roles=roles,
)

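For reference, the new `PHI` turn strings render as below (illustrative only; the format strings are copied verbatim from the diff above):

```python
# Format strings from AlpacaPrompter.match_prompt_style for PromptStyle.PHI
system_format = "<|system|>{system}\n"
turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>"

prompt = system_format.format(system="You are a helpful assistant.")
prompt += turn_no_input_format.format(instruction="Name three prime numbers.")
print(prompt)
# <|system|>You are a helpful assistant.
# <|user|>
# Name three prime numbers.<|end|><|assistant|>
```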
34 changes: 24 additions & 10 deletions src/axolotl/utils/callbacks/__init__.py
@@ -5,6 +5,7 @@
import logging
import math
import os
import traceback
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Dict, List
@@ -30,6 +31,7 @@

from axolotl.utils import is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.callbacks.perplexity import Perplexity
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
from axolotl.utils.distributed import (
barrier,
@@ -374,10 +376,14 @@ def __init__(self, cfg):
def __maybe_load_metrics(self):
metrics = {}
for metric in self.cfg.eval_causal_lm_metrics:
try:
metrics[metric] = evaluate.load(metric)
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.warning(f"{metric}: {exc.args}")
if metric == "perplexity":
max_seq_len = self.cfg.eval_max_new_tokens
metrics[metric] = Perplexity(trainer.model, tokenizer, max_seq_len)
else:
try:
metrics[metric] = evaluate.load(metric)
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.warning(f"{metric}: {exc.args}")
return metrics

def on_evaluate(
@@ -421,13 +427,20 @@ def compute(metric: evaluate.Metric, **kwargs):
# safely compute a metric and return the score if the format is correct
metric_score = None
try:
metric_score = metric.compute(**kwargs)
# Only pass the kwargs that are in the metric's feature list
metric_kwargs = {
k: kwargs[k]
for k in metric._feature_names() # pylint: disable=protected-access
if k in kwargs
}
metric_score = metric.compute(**metric_kwargs)
return (
metric_score["score"]
if "score" in metric_score
else metric_score["mean_score"]
)
except Exception: # pylint: disable=broad-exception-caught
traceback.print_exc()
LOG.debug(
f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
)
@@ -443,11 +456,12 @@ def evaluate_preds(sources, predictions, references):
predictions=predictions,
sources=sources,
)
score = score or compute(
metric,
references=[[r] for r in references],
predictions=predictions,
)
if score is None:
score = compute(
metric,
references=[[r] for r in references],
predictions=predictions,
)
scores[metric_name] = score
return scores

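The `_feature_names()` filtering keeps `evaluate`-style metrics and the custom `Perplexity` class behind one calling convention; a minimal standalone sketch of the idea (hypothetical helper, not part of this diff):

```python
from typing import Any, Dict


def compute_filtered(metric: Any, **kwargs: Any) -> Dict[str, float]:
    """Pass only the kwargs the metric declares in its feature list."""
    allowed = metric._feature_names()  # e.g. ["references"] for Perplexity
    return metric.compute(**{k: v for k, v in kwargs.items() if k in allowed})
```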
76 changes: 76 additions & 0 deletions src/axolotl/utils/callbacks/perplexity.py
@@ -0,0 +1,76 @@
"""callback to calculate perplexity as an evaluation metric."""
from typing import Dict, List, Optional

import torch
from torch import Tensor
from tqdm import tqdm
from transformers.modeling_outputs import CausalLMOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer


class Perplexity:
"""
Calculate perplexity as defined in https://huggingface.co/docs/transformers/en/perplexity.
This is a custom variant that doesn't re-tokenize the input or re-load the model.
"""

def __init__(
self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
max_seq_len: int,
stride: int = 512,
) -> None:
self.max_seq_len = max_seq_len
self.stride = stride
self.model = model
self.tokenizer = tokenizer
self.device = model.device
self.name = "perplexity"

def _feature_names(self) -> List[str]:
return ["references"]

def compute(
self,
references: Optional[List[str]] = None,
) -> Dict[str, float]:
"""
Compute perplexity in a fixed length sliding window across the sequence.
"""
assert references is not None, "Missing parameter: references"

references_tokenized = self.tokenizer(
references, return_tensors="pt", padding=True, truncation=True
)
input_ids: Tensor = references_tokenized["input_ids"] # type: ignore
input_ids = input_ids.to(self.device)

sequence_length = input_ids.size(1)

losses = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, sequence_length, self.stride)):
end_loc = min(begin_loc + self.max_seq_len, sequence_length)
trg_len = end_loc - prev_end_loc
input_ids_slice = input_ids[:, begin_loc:end_loc]
labels_slice = input_ids_slice.clone()
labels_slice[:, :-trg_len] = -100

with torch.no_grad():
outputs: CausalLMOutput = self.model(
input_ids=input_ids_slice, labels=labels_slice
)

losses.append(outputs.loss)

prev_end_loc = end_loc
if end_loc == sequence_length:
break

perplexity = torch.exp(torch.stack(losses).mean()).item()

return {
"score": perplexity,
}
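A standalone usage sketch for the new metric (assumes the model can be downloaded; any small causal LM would do):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from axolotl.utils.callbacks.perplexity import Perplexity

name = "microsoft/Phi-3-mini-4k-instruct"
model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)

ppl = Perplexity(model, tokenizer, max_seq_len=512)
result = ppl.compute(references=["The quick brown fox jumps over the lazy dog."])
print(result["score"])  # exp(mean loss) across the sliding windows
```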
1 change: 1 addition & 0 deletions src/axolotl/utils/chat_templates.py
@@ -25,6 +25,7 @@ def chat_templates(user_choice: str):
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
}

if user_choice in templates:
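The template can be exercised directly through `tokenizer.apply_chat_template` (a sketch; the output shown assumes Phi-3's `<s>` BOS token):

```python
from transformers import AutoTokenizer

from axolotl.utils.chat_templates import chat_templates

tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True
)
tokenizer.chat_template = chat_templates("phi_3")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is perplexity?"},
]
print(tokenizer.apply_chat_template(messages, tokenize=False))
# <s><|system|>
# You are a helpful assistant.<|end|>
# <|user|>
# What is perplexity?<|end|>
# <|assistant|>
```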
6 changes: 3 additions & 3 deletions src/axolotl/utils/config/__init__.py
@@ -10,6 +10,7 @@

from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import (
SUPPORTED_METRICS,
AxolotlConfigWCapabilities,
AxolotlInputConfig,
)
@@ -586,13 +587,12 @@ def legacy_validate_config(cfg):
)

if cfg.eval_causal_lm_metrics:
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
if not isinstance(cfg.eval_causal_lm_metrics, list):
raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
if set(cfg.eval_causal_lm_metrics) - SUPPORTED_METRICS:
raise ValueError(
f"eval_causal_lm_metrics must be one of {supported_metrics}"
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
)

# TODO
8 changes: 5 additions & 3 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -17,6 +17,8 @@

LOG = logging.getLogger("axolotl.utils.config.models.input")

SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}


class DeprecatedParameters(BaseModel):
"""configurations that are deprecated"""
@@ -176,6 +178,7 @@ class ChatTemplate(str, Enum):
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name


class LoftQConfig(BaseModel):
@@ -1072,13 +1075,12 @@ def check_causal_lm_evals(cls, data):
)

if data.get("eval_causal_lm_metrics"):
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
if not isinstance(data.get("eval_causal_lm_metrics"), list):
raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(data.get("eval_causal_lm_metrics")) - set(supported_metrics):
if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS:
raise ValueError(
f"eval_causal_lm_metrics must be one of {supported_metrics}"
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
)
return data

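With both validators importing the same constant, the check reduces to a set difference (sketch):

```python
from axolotl.utils.config.models.input.v0_4_1 import SUPPORTED_METRICS

requested = ["perplexity", "sacrebleu"]
unknown = set(requested) - SUPPORTED_METRICS
if unknown:
    raise ValueError(f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}")
```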
18 changes: 12 additions & 6 deletions src/axolotl/utils/data/sft.py
@@ -474,12 +474,16 @@ def load_prepare_datasets(
index=cfg.dataset_shard_idx,
)

if split == "train" and cfg.val_set_size:
val_set_size = (
int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
)

if split == "train" and val_set_size:
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
to_hash_train = (
dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ str(val_set_size)
+ "|"
+ "train"
+ "|"
@@ -488,7 +492,7 @@
to_hash_test = (
dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ str(val_set_size)
+ "|"
+ "test"
+ "|"
@@ -498,9 +502,7 @@
test_fingerprint = md5(to_hash_test)

dataset = dataset.train_test_split(
test_size=int(cfg.val_set_size)
if cfg.val_set_size == int(cfg.val_set_size)
else cfg.val_set_size,
test_size=val_set_size,
shuffle=False,
seed=cfg.seed or 42,
train_new_fingerprint=train_fingerprint,
@@ -535,6 +537,10 @@ def get_dataset_wrapper(
"keep_in_memory": cfg.dataset_keep_in_memory is True,
}

LOG.info(
f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
)

if (
isinstance(dataset, Dataset)
and "input_ids" in dataset.features
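The normalization above makes the intended semantics explicit: a `val_set_size` of at most 1 is treated as a fraction of the dataset, anything larger as an absolute row count, matching how `datasets.Dataset.train_test_split` interprets `test_size` (a quick illustration):

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": [f"example {i}" for i in range(1000)]})
# float <= 1 -> fraction of rows; int > 1 -> absolute row count
print(ds.train_test_split(test_size=0.01, shuffle=False)["test"].num_rows)  # 10
print(ds.train_test_split(test_size=25, shuffle=False)["test"].num_rows)    # 25
```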