Misc improvements #5

Closed
wants to merge 9 commits
9 changes: 7 additions & 2 deletions src/axolotl/prompt_strategies/chat_template.py
@@ -63,7 +63,7 @@ def build_prompt(self, conversation, add_generation_prompt=False):

return self.tokenizer.apply_chat_template(
turns,
-truncation=True,
+truncation=False,
max_length=self.max_length,
add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template,
@@ -338,10 +338,15 @@ def get_conversation_thread(self, prompt):

def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
+chat_template = (
+ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
+)
+chat_template_str = chat_templates(chat_template, tokenizer=tokenizer)
+LOG.info(f"Using chat template:\n---\n{chat_template_str!s}\n---")

prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"chat_template": chat_template_str,
"message_field_role": ds_cfg.get("message_field_role", "from"),
"message_field_content": ds_cfg.get("message_field_content", "value"),
"message_field_training": ds_cfg.get("message_field_training", "training"),
6 changes: 3 additions & 3 deletions src/axolotl/prompt_strategies/orpo/chat_template.py
@@ -33,7 +33,7 @@ def load(
if ds_cfg and "chat_template" in ds_cfg:
chat_template = ds_cfg["chat_template"]
try:
-chat_template = chat_templates(chat_template)
+chat_template = chat_templates(chat_template, tokenizer=tokenizer)
except ValueError:
pass
tokenizer.chat_template = chat_template
@@ -248,11 +248,11 @@ def build_prompt(
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
dataset_parser = ORPODatasetParsingStrategy()

-chat_template_str = chat_templates(cfg.chat_template)
-
def transform_fn(sample, tokenizer=None):
res = {}

+chat_template_str = chat_templates(cfg.chat_template, tokenizer=tokenizer)
+
res["prompt"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
add_generation_prompt=True,
7 changes: 6 additions & 1 deletion src/axolotl/train.py
@@ -143,7 +143,7 @@ def train(
model.config.save_pretrained(str(Path(cfg.output_dir)))

# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
-if cfg.local_rank == 0:
+if cfg.local_rank == 0 and cfg.save_model_on_interrupt:

def terminate_handler(_, __, model_weakref):
if model_weakref() is not None:
@@ -248,6 +248,11 @@ def terminate_handler(_, __, model_weakref):
# defensively push to the hub to ensure the model card is updated
trainer.push_to_hub()

+if cfg.deepspeed:
+trainer.deepspeed.destroy()
+trainer.accelerator.free_memory()
+trainer.model, trainer.model_wrapped, trainer.optimizer = None, None, None
+
return model, tokenizer


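For reference, a minimal, self-contained sketch of the interrupt-save pattern the patch gates behind the new save_model_on_interrupt flag; install_interrupt_save, the cfg fields, and the save path are illustrative stand-ins rather than part of the patch:

    import signal
    import weakref


    def install_interrupt_save(model, cfg):
        # Only the main process saves, and only when the feature is enabled in the config.
        if cfg.local_rank != 0 or not cfg.save_model_on_interrupt:
            return

        model_ref = weakref.ref(model)  # weak reference so the handler does not keep the model alive

        def terminate_handler(signum, frame):
            live_model = model_ref()
            if live_model is not None:
                live_model.save_pretrained(cfg.output_dir)
            raise KeyboardInterrupt  # let the surrounding training loop unwind normally

        signal.signal(signal.SIGINT, terminate_handler)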
36 changes: 35 additions & 1 deletion src/axolotl/utils/chat_templates.py
@@ -2,9 +2,15 @@
This module provides functionality for selecting chat templates based on user choices.
These templates are used for formatting messages in a conversation.
"""
+import logging

+LOG = logging.getLogger("axolotl.utils.chat_templates")

-def chat_templates(user_choice: str):
+_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
+_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"
+
+
+def chat_templates(user_choice: str, tokenizer=None):
"""
Finds the correct chat_template for the tokenizer_config.

@@ -21,6 +27,7 @@ def chat_templates(user_choice: str):
templates = {
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
"mistral": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] | trim + '\n\n' %}{% set messages = messages[1:] %}{% else %}{% set system_message = '' %}{% endif %}{{ bos_token + system_message }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] | trim + eos_token }}{% endif %}{% endfor %}",
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
@@ -29,6 +36,33 @@ def chat_templates(user_choice: str):
"deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}",
}

+if user_choice == _DEFAULT_TEMPLATE_CHOICE:
+if not tokenizer:
+raise ValueError(
+f"`tokenizer` cannot be None when chat_template choice is {_DEFAULT_TEMPLATE_CHOICE}"
+)
+if not tokenizer.chat_template:
+raise ValueError(
+f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
+f"Please add a chat_template in tokenizer config"
+)
+return tokenizer.chat_template
+
+if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
+if not tokenizer:
+raise ValueError(
+f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
+)
+if tokenizer.chat_template:
+return tokenizer.chat_template
+
+user_choice = user_choice[
+len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
+]
+LOG.warning(
+f"No chat template found on tokenizer, falling back to {user_choice}"
+)
+
if user_choice in templates:
return templates[user_choice]

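A short usage sketch of the three kinds of choices the extended chat_templates() now accepts, assuming this patch is installed; the tokenizer checkpoint is only an example:

    from transformers import AutoTokenizer

    from axolotl.utils.chat_templates import chat_templates

    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")

    # 1. A named built-in template ignores the tokenizer entirely.
    chatml = chat_templates("chatml")

    # 2. "tokenizer_default" requires a tokenizer and raises if it carries no chat_template.
    native = chat_templates("tokenizer_default", tokenizer=tokenizer)

    # 3. "tokenizer_default_fallback_<name>" prefers the tokenizer's own template and
    #    otherwise logs a warning and falls back to the named built-in ("chatml" here).
    resolved = chat_templates("tokenizer_default_fallback_chatml", tokenizer=tokenizer)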
23 changes: 20 additions & 3 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -8,9 +8,16 @@
import os
from enum import Enum
from importlib.metadata import version
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union

-from pydantic import BaseModel, Field, conlist, field_validator, model_validator
+from pydantic import (
+BaseModel,
+Field,
+StringConstraints,
+conlist,
+field_validator,
+model_validator,
+)
from transformers import SchedulerType
from transformers.training_args import OptimizerNames

@@ -190,6 +197,7 @@ class ChatTemplate(str, Enum):
llama3 = "llama3" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
+tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name


class LoftQConfig(BaseModel):
@@ -673,10 +681,19 @@ class Config:
gpu_memory_limit: Optional[Union[int, str]] = None
low_cpu_mem_usage: Optional[bool] = None

-chat_template: Optional[ChatTemplate] = None
+chat_template: Optional[
+Union[
+ChatTemplate,
+Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
+]
+] = None
default_system_message: Optional[str] = None

fix_untrained_tokens: Optional[bool] = None
+# Added by TrueFoundry Team
+save_model_on_interrupt: bool = True
+drop_long_sequences: bool = True
+############################

# INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None
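A minimal sketch of how the widened chat_template field validates, assuming pydantic v2 (which provides StringConstraints); the trimmed-down ChatTemplate enum and Cfg model below are stand-ins for the real config classes:

    from enum import Enum
    from typing import Annotated, Optional, Union

    from pydantic import BaseModel, StringConstraints


    class ChatTemplate(str, Enum):
        chatml = "chatml"
        tokenizer_default = "tokenizer_default"


    class Cfg(BaseModel):
        chat_template: Optional[
            Union[
                ChatTemplate,
                Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
            ]
        ] = None


    Cfg(chat_template="chatml")  # accepted: enum member
    Cfg(chat_template="tokenizer_default")  # accepted: new enum member
    Cfg(chat_template="tokenizer_default_fallback_chatml")  # accepted: matches the prefix pattern
    # Cfg(chat_template="unknown_template")  # would raise a ValidationError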
17 changes: 17 additions & 0 deletions src/axolotl/utils/data/sft.py
@@ -66,6 +66,19 @@ def prepare_dataset(cfg, tokenizer):
train_dataset, eval_dataset, prompters = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
+if len(train_dataset) == 0:
+raise ValueError(
+"No samples left in train data after loading and processing. "
+)
+
+if (
+eval_dataset is not None
+and (cfg.val_set_size or cfg.test_datasets)
+and len(eval_dataset) == 0
+):
+raise ValueError(
+"No samples left in eval data after loading and processing. "
+)
else:
path = cfg.pretraining_dataset
split = "train"
@@ -99,6 +112,10 @@ def prepare_dataset(cfg, tokenizer):
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
+if len(train_dataset) == 0:
+raise ValueError(
+"No samples left in train data after loading and processing. "
+)
return train_dataset, eval_dataset, cfg.max_steps, prompters

if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
2 changes: 1 addition & 1 deletion src/axolotl/utils/models.py
@@ -285,7 +285,7 @@ def load_tokenizer(cfg):
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")

if cfg.chat_template:
-chat_template_string = chat_templates(cfg.chat_template)
+chat_template_string = chat_templates(cfg.chat_template, tokenizer=tokenizer)
if cfg.default_system_message and cfg.chat_template == "chatml":
chat_template_string = chat_template_string.replace(
"You are a helpful assistant.", cfg.default_system_message
21 changes: 21 additions & 0 deletions src/axolotl/utils/samplers/utils.py
@@ -15,3 +15,24 @@ def get_dataset_lengths(dataset):
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
return lengths
return lengths


+def plot_ascii_lengths_histogram(data, title, logger):
+max_value = max(data)
+bucket_width = 512
+bins = np.arange(0, max_value + bucket_width, bucket_width)
+histogram, _ = np.histogram(data, bins=bins)
+top = " ".join(("-" * 10, title, "-" * 10))
+bottom = "-" * len(top)
+logger.info(top)
+scale_factor = 40 / max(histogram)
+for i, value in enumerate(histogram):
+lower_bound = i * bucket_width
+upper_bound = (i + 1) * bucket_width - 1
+if value:
+hist_bar = "□" * int(value * scale_factor)
+else:
+hist_bar = "x"
+logger.info(f"{hist_bar} ({lower_bound}-{upper_bound} tokens, Count: {value})")
+logger.info(bottom)
+logger.info("\n")
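A quick usage sketch of the new histogram helper, assuming this patch is applied; the synthetic lengths exist only to produce a recognizable shape in the log output:

    import logging

    import numpy as np

    from axolotl.utils.samplers.utils import plot_ascii_lengths_histogram

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("lengths-demo")

    # Synthetic token lengths: mostly short samples plus a handful of long outliers.
    lengths = np.concatenate(
        [
            np.random.randint(100, 900, size=2000),
            np.random.randint(4000, 6000, size=20),
        ]
    )
    plot_ascii_lengths_histogram(data=lengths, title="Train Dataset Lengths", logger=logger)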
71 changes: 54 additions & 17 deletions src/axolotl/utils/trainer.py
@@ -18,6 +18,7 @@
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
+from axolotl.utils.samplers.utils import plot_ascii_lengths_histogram

LOG = get_logger("axolotl")

@@ -170,26 +171,62 @@ def add_length(sample):
return sample


+def drop_no_outputs(sample):
+return any(v != -100 for v in sample["labels"])
+
+
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
return (
len(sample["input_ids"]) <= sequence_len
and len(sample["input_ids"]) >= min_sequence_len
)


-def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
+def _maybe_drop_sequences(cfg, dataset, ds_split_name: str):
+_ds_lens = get_dataset_lengths(dataset)
+plot_ascii_lengths_histogram(
+data=_ds_lens, title=f"{ds_split_name} Dataset Lengths", logger=LOG
+)
+min_len, max_len = np.min(_ds_lens), np.max(_ds_lens)
+LOG.debug(f"min_input_len: {min_len}", main_process_only=True)
+LOG.debug(f"max_input_len: {max_len}", main_process_only=True)
drop_long = partial(
drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len or 2,
)
+len_pre_drop = len(dataset)
+dataset = dataset.filter(
+drop_long,
+num_proc=cfg.dataset_processes,
+load_from_cache_file=not cfg.is_preprocess,
+desc=f"Dropping Long Sequences From {ds_split_name} Dataset",
+)
+dropped_rows = len_pre_drop - len(dataset)
+if dropped_rows > 0:
+LOG.warning(f"Dropped {dropped_rows} rows from {ds_split_name} dataset")
+if not cfg.drop_long_sequences:
+raise ValueError(
+f"Found {dropped_rows} sequences longer than {cfg.sequence_len} tokens in {ds_split_name} Dataset. "
+f"Longest sequence is {max_len} tokens. "
+f"Please either increase --sequence_len or set --drop_long_sequences to True to drop and ignore such sequences."
+)
+len_pre_drop = len(dataset)
+dataset = dataset.filter(
+drop_no_outputs,
+num_proc=cfg.dataset_processes,
+load_from_cache_file=not cfg.is_preprocess,
+desc="Dropping Sequences Without Outputs",
+)
+dropped_rows = len_pre_drop - len(dataset)
+if dropped_rows > 0:
+LOG.warning(
+f"Dropped {dropped_rows} rows with no outputs from {ds_split_name} Dataset"
+)
+return dataset

-if cfg.is_preprocess:
-min_input_len = np.min(get_dataset_lengths(train_dataset))
-LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
-max_input_len = np.max(get_dataset_lengths(train_dataset))
-LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)

+def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
@@ -203,18 +240,13 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")

-train_dataset = train_dataset.filter(
-drop_long,
-num_proc=cfg.dataset_processes,
-load_from_cache_file=not cfg.is_preprocess,
-desc="Dropping Long Sequences",
+train_dataset = _maybe_drop_sequences(
+cfg=cfg, dataset=train_dataset, ds_split_name="Train"
)

if eval_dataset:
-eval_dataset = eval_dataset.filter(
-drop_long,
-num_proc=cfg.dataset_processes,
-load_from_cache_file=not cfg.is_preprocess,
-desc="Dropping Long Sequences",
+eval_dataset = _maybe_drop_sequences(
+cfg=cfg, dataset=eval_dataset, ds_split_name="Eval"
)

if cfg.group_by_length:
@@ -274,10 +306,15 @@ def process_pretraining_datasets_for_packing(
):
drop_long = partial(drop_long_seq, sequence_len=sequence_len)

+_len_pre_drop = len(train_dataset)
train_dataset = train_dataset.filter(
drop_long,
desc="Dropping Long Sequences",
desc="Dropping Long Sequences From Train Dataset",
)
+_dropped_rows = _len_pre_drop - len(train_dataset)
+if _dropped_rows > 0:
+LOG.warning(f"Dropped {_dropped_rows} rows")
+
if skip_position_ids:
train_dataset = train_dataset.map(
add_position_ids,
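To illustrate the two per-sample filters the patch applies before packing, a small sketch on a toy Hugging Face Dataset; it assumes this patch is applied so that drop_no_outputs is importable next to drop_long_seq:

    from datasets import Dataset

    from axolotl.utils.trainer import drop_long_seq, drop_no_outputs

    ds = Dataset.from_list(
        [
            {"input_ids": [1, 2, 3], "labels": [-100, 2, 3]},  # kept
            {"input_ids": [1], "labels": [1]},  # dropped: shorter than min_sequence_len
            {"input_ids": list(range(5000)), "labels": list(range(5000))},  # dropped: longer than sequence_len
            {"input_ids": [1, 2, 3], "labels": [-100, -100, -100]},  # dropped: no trainable labels
        ]
    )

    ds = ds.filter(lambda sample: drop_long_seq(sample, sequence_len=2048, min_sequence_len=2))
    ds = ds.filter(drop_no_outputs)
    print(len(ds))  # 1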