Skip to content

Commit

Permalink
Misc improvements
Browse files Browse the repository at this point in the history
Enable callbacks injection from plugins

Fix misc issues with axolotl plugins

Fix remote code checking

Enable loss average across devices

Add seq len validation

Enhance sequence lens validation

Remove legacy code for patching _get_unpad_data

Add pre truncation token counting for completion

Fix plugin callbacks duplication
  • Loading branch information
chiragjn committed Nov 19, 2024
1 parent db51a9e commit 0011a39
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 40 deletions.
34 changes: 15 additions & 19 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,11 +1212,17 @@ def get_post_trainer_create_callbacks(self, trainer):
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
callbacks = []

plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
)
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
callbacks.extend(
[
cb
for cb in plugin_manager.add_callbacks_post_trainer(
self.cfg, trainer
)
if cb
]
)
return callbacks

def hook_pre_create_training_args(self, training_arguments_kwargs):
Expand Down Expand Up @@ -1263,7 +1269,7 @@ def get_callbacks(self):
return callbacks

def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
Expand Down Expand Up @@ -1301,17 +1307,7 @@ def get_post_trainer_create_callbacks(self, trainer):
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))

if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
callbacks.extend(
[
cb
for cb in plugin_manager.add_callbacks_post_trainer(
self.cfg, trainer
)
if cb
]
)
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks

def _get_trainer_cls(self):
Expand Down Expand Up @@ -1340,7 +1336,7 @@ def build(self, total_num_steps):
else max(min(int(0.005 * total_num_steps), 10), 1)
)

training_arguments_kwargs = {}
training_arguments_kwargs = {"average_tokens_across_devices": True}
if self.cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True
else:
Expand Down Expand Up @@ -1847,7 +1843,7 @@ def get_post_trainer_create_callbacks(self, trainer):
return callbacks

def build_training_arguments(self, total_num_steps):
training_args_kwargs = {}
training_args_kwargs = {"average_tokens_across_devices": True}
for arg in [
"adam_beta1",
"adam_beta2",
Expand Down
8 changes: 7 additions & 1 deletion src/axolotl/logging_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,17 @@ def format(self, record):
"filters": [],
"stream": sys.stdout,
},
"file": {
"class": "logging.FileHandler",
"formatter": "simple",
"filename": "train.log",
"mode": "w",
},
},
"root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
"loggers": {
"axolotl": {
"handlers": ["color_console"],
"handlers": ["color_console", "file"],
"level": "DEBUG",
"propagate": False,
},
Expand Down
6 changes: 6 additions & 0 deletions src/axolotl/prompt_strategies/alpaca_w_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def tokenize_prompt(self, prompt):
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]

if "num_tokens_pre_truncation" in tokenized_prompt:
tokenized_prompt["num_tokens_pre_truncation"] = (
tokenized_prompt["num_tokens_pre_truncation"]
+ tokenized_res_prompt["num_tokens_pre_truncation"]
)

return tokenized_prompt


Expand Down
57 changes: 42 additions & 15 deletions src/axolotl/prompt_strategies/chat_template.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
HF Chat Templates prompt strategy
"""

import functools
import logging
from typing import Any, Dict, List, Optional

Expand Down Expand Up @@ -64,14 +64,16 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None):

if self.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]

if self.processor:
text = self.processor.apply_chat_template(
turns,
_apply_chat_template = functools.partial(
self.processor.apply_chat_template,
chat_template=self.chat_template,
tokenize=False,
add_generation_prompt=add_generation_prompt,
)
text = _apply_chat_template(
turns,
tokenize=False,
)
batch = self.processor(
text=text,
images=images,
Expand All @@ -85,15 +87,27 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None):
batch[k] = val.tolist()
else:
batch[k] = val.squeeze().tolist()
batch["num_tokens_pre_truncation"] = len(
_apply_chat_template(turns, tokenize=True)
)
return batch

return self.tokenizer.apply_chat_template(
turns,
truncation=True,
_apply_chat_template = functools.partial(
self.tokenizer.apply_chat_template,
max_length=self.max_length,
add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template,
)
inputs = _apply_chat_template(
turns,
truncation=True,
)
return {
"input_ids": inputs,
"num_tokens_pre_truncation": len(
_apply_chat_template(turns, truncation=False)
),
}

def get_offsets_for_train_detail(
self, text: str, train_details: List[Dict], mask_untrainable: bool = True
Expand Down Expand Up @@ -237,20 +251,29 @@ def tokenize_prompt(self, prompt):
):
turns = self.get_conversation_thread(prompt)
images = self.get_images(prompt)
prompt_ids = self.prompter.build_prompt(
prompt_tokenized = self.prompter.build_prompt(
turns[:-1],
add_generation_prompt=True,
images=images,
)
tokenized_res = self.prompter.build_prompt(turns, images=images)
all_turns_tokenized = self.prompter.build_prompt(turns, images=images)
tokenized_prompt = {}
if isinstance(tokenized_res, list):
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
if "attention_mask" not in all_turns_tokenized:
prompt_ids = prompt_tokenized["input_ids"]
input_ids = (
prompt_ids + all_turns_tokenized["input_ids"][len(prompt_ids) :]
)
tokenized_prompt["input_ids"] = input_ids
num_tokens_pre_truncation = all_turns_tokenized[
"num_tokens_pre_truncation"
]
tokenized_prompt["attention_mask"] = [1] * len(input_ids)
else:
input_ids = tokenized_res["input_ids"]
tokenized_prompt = tokenized_res
input_ids = all_turns_tokenized["input_ids"]
num_tokens_pre_truncation = all_turns_tokenized[
"num_tokens_pre_truncation"
]
tokenized_prompt = all_turns_tokenized

if not self.train_on_inputs:
user_prompt_len = len(prompt_ids)
Expand All @@ -259,11 +282,14 @@ def tokenize_prompt(self, prompt):
labels = input_ids

tokenized_prompt["labels"] = labels
tokenized_prompt["num_tokens_pre_truncation"] = num_tokens_pre_truncation

return tokenized_prompt

turns = prompt[self.messages]
input_ids = self.prompter.build_prompt(turns)
tokenized_res = self.prompter.build_prompt(turns)
input_ids = tokenized_res["input_ids"]
num_tokens_pre_truncation = tokenized_res["num_tokens_pre_truncation"]
labels = [IGNORE_TOKEN_ID] * len(input_ids)

last_eos_idx = -1
Expand Down Expand Up @@ -342,6 +368,7 @@ def tokenize_prompt(self, prompt):
"input_ids": input_ids,
"labels": labels,
"attention_mask": [1] * len(input_ids),
"num_tokens_pre_truncation": num_tokens_pre_truncation,
}

def find_eos_token(self, input_ids, start_idx):
Expand Down
28 changes: 24 additions & 4 deletions src/axolotl/prompt_tokenizers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module containing PromptTokenizingStrategy and Prompter classes"""

import abc
import functools
import logging
from typing import Dict, List, Tuple, Union

Expand Down Expand Up @@ -60,18 +61,23 @@ def supports_batched(self):
def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding:
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
empty = BatchEncoding(
data={"input_ids": [], "attention_mask": [], "num_tokens_pre_truncation": 0}
)
if not prompt:
LOG.warning("Empty text requested for tokenization.")
return empty

result = self.tokenizer(
prompt,
truncation=True,
_tokenize = functools.partial(
self.tokenizer,
max_length=self.max_length,
padding=False,
return_tensors=None,
)
result = _tokenize(
prompt,
truncation=True,
)
if len(result["input_ids"]) == 0:
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
return empty
Expand All @@ -89,6 +95,20 @@ def _tokenize(
result["attention_mask"] = result["attention_mask"][1:]

result["labels"] = result["input_ids"].copy()

_all_tokens = _tokenize(prompt, truncation=False)
num_tokens_pre_truncation = len(_all_tokens["input_ids"])
if (
_all_tokens["input_ids"][-1] != self.tokenizer.eos_token_id
and add_eos_token
):
num_tokens_pre_truncation += 1
if (
_all_tokens["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
num_tokens_pre_truncation -= 1
result["num_tokens_pre_truncation"] = num_tokens_pre_truncation
return result


Expand Down
10 changes: 9 additions & 1 deletion src/axolotl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from axolotl.common.cli import TrainerCliArgs
from axolotl.core.tokenizer_utils import fix_untrained_tokens
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
Expand Down Expand Up @@ -99,6 +100,8 @@ def train(
model, peft_config = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_model_load(cfg, model)
if model.generation_config is not None:
model.generation_config.do_sample = True

Expand Down Expand Up @@ -148,7 +151,7 @@ def train(
model.config.save_pretrained(str(Path(cfg.output_dir)))

# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
if cfg.local_rank == 0 and cfg.get("save_model_on_interrupt", True):

def terminate_handler(_, __, model_weakref):
if model_weakref() is not None:
Expand Down Expand Up @@ -269,6 +272,11 @@ def terminate_handler(_, __, model_weakref):
# defensively push to the hub to ensure the model card is updated
trainer.push_to_hub()

if cfg.deepspeed:
trainer.deepspeed.destroy()
trainer.accelerator.free_memory()
trainer.model, trainer.model_wrapped, trainer.optimizer = None, None, None

return model, tokenizer


Expand Down
21 changes: 21 additions & 0 deletions src/axolotl/utils/samplers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,24 @@ def get_dataset_lengths(dataset):
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
return lengths
return lengths


def plot_ascii_lengths_histogram(data, title, logger):
max_value = max(data)
bucket_width = 512
bins = np.arange(0, max_value + bucket_width, bucket_width)
histogram, _ = np.histogram(data, bins=bins)
top = " ".join(("-" * 10, title, "-" * 10))
bottom = "-" * len(top)
logger.info(top)
scale_factor = 40 / max(histogram)
for i, value in enumerate(histogram):
lower_bound = i * bucket_width
upper_bound = (i + 1) * bucket_width - 1
if value:
hist_bar = "□" * int(value * scale_factor)
else:
hist_bar = "x"
logger.info(f"{hist_bar} ({lower_bound}-{upper_bound} tokens, Count: {value})")
logger.info(bottom)
logger.info("\n")
Loading

0 comments on commit 0011a39

Please sign in to comment.