Transformers 4.47.0 (axolotl-ai-cloud#2138)
* bump transformers and trl

* fix: update trainer.log signature

* fix trl trainer.log interfaces

* broken 🦥 with latest transformers

* skip parent, call grandparent - yeah, super janky

* update HF HUB env var and fix reward trainer log since it doesn't directly override log

* also bump accelerate

* patches for llama ga

* detab the code to check

* fix whitespace for patch check

* play nicely with CI tests since we patch every time

* fix pop default in case it doesn't exist

* more tweaks to make patches nicer in CI

* fix detab for when there are possibly multiple patches

---------

Co-authored-by: NanoCode012 <[email protected]>
winglian and NanoCode012 authored Dec 7, 2024
1 parent f9a7748 commit 743ba62
Showing 10 changed files with 347 additions and 19 deletions.
2 changes: 1 addition & 1 deletion docker/Dockerfile-cloud
@@ -2,7 +2,7 @@ ARG BASE_TAG=main
FROM axolotlai/axolotl:$BASE_TAG

ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
-ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
+ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
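
The rename above (here and in the no-tmux variant below) tracks huggingface_hub, which standardized on HF_HUB_CACHE for the hub cache path; the older HUGGINGFACE_HUB_CACHE name survives only as a legacy alias. A minimal sanity check of which path gets resolved, assuming recent huggingface_hub releases expose the resolved value via huggingface_hub.constants:

import os

# Set the variable before importing huggingface_hub: the constants are
# resolved once, at import time.
os.environ["HF_HUB_CACHE"] = "/workspace/data/huggingface-cache/hub"

from huggingface_hub import constants  # pylint: disable=wrong-import-position

print(constants.HF_HUB_CACHE)  # -> /workspace/data/huggingface-cache/hub
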
2 changes: 1 addition & 1 deletion docker/Dockerfile-cloud-no-tmux
@@ -2,7 +2,7 @@ ARG BASE_TAG=main
FROM axolotlai/axolotl:$BASE_TAG

ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
-ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
+ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"

6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,10 +1,10 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.14.0
-transformers==4.46.3
+transformers==4.47.0
tokenizers>=0.20.1
bitsandbytes==0.45.0
-accelerate==1.1.0
+accelerate==1.2.0
datasets==3.1.0
deepspeed==0.15.4
pydantic==2.6.3
@@ -42,7 +42,7 @@ s3fs>=2024.5.0
gcsfs>=2024.5.0
# adlfs

-trl==0.12.0
+trl==0.12.1
zstandard==0.22.0
fastcore

2 changes: 1 addition & 1 deletion src/axolotl/cli/__init__.py
@@ -442,7 +442,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
            "compute_capability": gpu_version,
        },
        env_capabilities={
-            "torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
+            "torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
        },
    )

87 changes: 85 additions & 2 deletions src/axolotl/core/trainer_builder.py
@@ -957,21 +957,23 @@ def create_accelerator_and_postprocess(self):

        return res

-    def log(self, logs: Dict[str, float]) -> None:
+    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
        """
        Log `logs` on the various objects watching training, including stored metrics.
        Args:
            logs (`Dict[str, float]`):
                The values to log.
            start_time (`Optional[float]`):
                The start of training.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]
-        return super().log(logs)
+        return super().log(logs, start_time)

    def store_metrics(
        self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1155,6 +1157,18 @@ def training_step(
        torch.cuda.empty_cache()
        return loss

    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
        # TODO remove once trl supports the updated Trainer.log method
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]
        return super(DPOTrainer, self).log(  # pylint: disable=bad-super-call
            logs, start_time
        )


class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
    """
@@ -1163,6 +1177,18 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):

    tag_names = ["axolotl", "orpo"]

    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
        # TODO remove once trl supports the updated Trainer.log method
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]
        return super(ORPOTrainer, self).log(  # pylint: disable=bad-super-call
            logs, start_time
        )


class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
    """
@@ -1171,6 +1197,45 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):

    tag_names = ["axolotl", "kto"]

    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
        # TODO remove once trl supports the updated Trainer.log method
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # train metrics should have no prefix, eval should have 'eval_'
        prefix = "eval_" if train_eval == "eval" else ""
        # accumulate average metrics from sums and lengths
        for split in ["chosen", "rejected"]:
            if f"count/{split}" in self._stored_metrics[train_eval]:
                count_sum = (
                    torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
                    .sum()
                    .item()
                )
                for metric in ["rewards", "logps", "logits"]:
                    logs[f"{prefix}{metric}/{split}"] = (
                        torch.Tensor(
                            self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
                        )
                        .sum()
                        .item()
                        / count_sum
                    )
                    # delete obsolete metric
                    del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
                del self._stored_metrics[train_eval][f"count/{split}"]
        # calculate reward margin
        if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
            logs[f"{prefix}rewards/margins"] = (
                logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
            )
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]
        return super(KTOTrainer, self).log(  # pylint: disable=bad-super-call
            logs, start_time
        )
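
The KTO override above turns stored per-batch sums into means before logging: each {metric}/{split}_sum list is summed and divided by the summed count/{split}, and the chosen/rejected reward means then yield the margin. A toy illustration of that reduction, standalone and with made-up numbers:

import torch

stored = {
    "count/chosen": [2, 3],            # per-batch example counts
    "rewards/chosen_sum": [1.0, 4.5],  # per-batch reward sums
}

count_sum = torch.Tensor(stored["count/chosen"]).sum().item()  # 5.0
reward_mean = torch.Tensor(stored["rewards/chosen_sum"]).sum().item() / count_sum

print(reward_mean)  # 1.1 -> the value logged as rewards/chosen
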


class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
    """
@@ -1179,6 +1244,18 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):

    tag_names = ["axolotl", "cpo"]

    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
        # TODO remove once trl supports the updated Trainer.log method
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]
        return super(CPOTrainer, self).log(  # pylint: disable=bad-super-call
            logs, start_time
        )


class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
    """
@@ -1187,6 +1264,12 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):

    tag_names = ["axolotl", "reward"]

    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
        # TODO remove once trl supports the updated Trainer.log method
        return super(RewardTrainer, self).log(  # pylint: disable=bad-super-call
            logs, start_time
        )


class TrainerBuilderBase(abc.ABC):
    """
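
The repeated super(DPOTrainer, self).log(...) calls above are the "skip parent, call grandparent" trick from the commit message: two-argument super() starts the method lookup after the named class, so each trl trainer's own log (which does not yet accept start_time) is skipped in favor of transformers' Trainer.log. A minimal sketch of the pattern, with hypothetical stand-in classes:

from typing import Optional


class Grandparent:  # stands in for transformers.Trainer
    def log(self, logs: dict, start_time: Optional[float] = None) -> None:
        print(f"Grandparent.log({logs}, start_time={start_time})")


class Parent(Grandparent):  # stands in for a trl trainer with the old signature
    def log(self, logs: dict) -> None:
        print(f"Parent.log({logs})")


class Child(Parent):  # stands in for the Axolotl subclass
    def log(self, logs: dict, start_time: Optional[float] = None) -> None:
        # super(Parent, self) resumes the MRO search after Parent,
        # so this resolves to Grandparent.log and bypasses Parent.log.
        return super(Parent, self).log(logs, start_time)


Child().log({"loss": 0.1}, start_time=0.0)  # prints Grandparent.log(...)
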
207 changes: 207 additions & 0 deletions src/axolotl/monkeypatch/trainer_grad_accum.py
@@ -0,0 +1,207 @@
"""
fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/35128
"""
import inspect

from accelerate.logging import get_logger
from transformers import LlamaForCausalLM
from transformers.trainer import Trainer

from axolotl.monkeypatch.unsloth_ import detab_code

LOG = get_logger("axolotl.monkeypatch.trainer_grad_accum")

ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
"""

PATCHED_CONTEXT_CODE = """
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
else:
loss = self.compute_loss(model, inputs)
"""

ORIGINAL_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
"""

PATCHED_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
"""


def get_training_step_code() -> str:
training_step = inspect.getsource(
Trainer.training_step # pylint: disable=protected-access
)
return training_step


def check_training_step_is_patchable() -> bool:
training_step = get_training_step_code()
training_step, _ = detab_code(training_step)
return ORIGINAL_CONTEXT_CODE in training_step


def patch_training_step_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""

try:
training_step = get_training_step_code()
except OSError:
return
Trainer._original_training_step = training_step # pylint: disable=protected-access
training_step, _ = detab_code(training_step)
if ORIGINAL_CONTEXT_CODE not in training_step:
return
# assert (
# ORIGINAL_CONTEXT_CODE in training_step
# ), "Original training_step code not found"

training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
training_step = training_step.replace(
"def training_step(",
"def _fixed_training_step(",
1,
)

# load imports necessary
import transformers.trainer

items_to_import = []
for item in dir(transformers.trainer):
if item in training_step:
items_to_import.append(item)

exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching training_step", main_process_only=True)
Trainer.training_step = ( # pylint: disable=protected-access
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
)


def get_model_forward_code() -> str:
forward = inspect.getsource(
LlamaForCausalLM.forward # pylint: disable=protected-access
)
return forward


def check_forward_is_patchable() -> bool:
forward = get_model_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_LLAMA_FCLM_CODE in forward


def patch_forward_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""

try:
forward = get_model_forward_code()
except OSError:
return
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
if ORIGINAL_LLAMA_FCLM_CODE not in forward:
return
# assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"

forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
forward = forward.replace(
"def forward(",
"def _fixed_forward(",
1,
)

# load imports necessary
import transformers.models.llama.modeling_llama

items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)

exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching forward", main_process_only=True)
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
)
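
Both patch functions above follow the same fetch-modify-exec recipe: read the live source with inspect.getsource, dedent it, bail out if the expected snippet is absent (so re-running under CI after a prior patch is a no-op), string-replace the snippet, rename the definition, exec it, and rebind the attribute. A condensed, self-contained sketch of that recipe; patch_method and Greeter are hypothetical illustrations, not part of axolotl:

import inspect
import textwrap


class Greeter:
    def greet(self) -> str:
        return "hello"


def patch_method(cls, name: str, old: str, new: str) -> None:
    # Fetch and dedent the current method source so snippet matching
    # is indentation-safe (the role detab_code plays above).
    src = textwrap.dedent(inspect.getsource(getattr(cls, name)))
    if old not in src:
        return  # already patched, or upstream changed: leave the class alone
    src = src.replace(old, new)
    # Execute the rewritten definition against the defining module's globals
    # so it keeps access to the names its body refers to.
    scope = dict(vars(inspect.getmodule(cls)))
    exec(src, scope)  # pylint: disable=exec-used  # nosec B102
    setattr(cls, name, scope[name])


patch_method(Greeter, "greet", 'return "hello"', 'return "patched hello"')
print(Greeter().greet())  # -> patched hello
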
