support for fp8
winglian committed Nov 10, 2023

1 parent 105d0b3 commit ac49990
Showing 3 changed files with 9 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/axolotl/core/trainer_builder.py
@@ -483,6 +483,10 @@ def build(self, total_num_steps):
         training_arguments_kwargs["fp16"] = (
             self.cfg.fp16 and not self.cfg.bf16
         ) or False
+        if self.cfg.fp8:
+            training_arguments_kwargs["fp16"] = False
+            training_arguments_kwargs["bf16"] = False
+
         training_arguments_kwargs["tf32"] = self.cfg.tf32
         training_arguments_kwargs["warmup_steps"] = warmup_steps
         training_arguments_kwargs["logging_steps"] = logging_steps
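With fp8 enabled, the builder forces both fp16 and bf16 to False in the kwargs handed to transformers.TrainingArguments, so mixed precision is delegated to Accelerate rather than the Trainer's own amp paths; the cfg.fp8 flag would typically come from an fp8: true entry in the axolotl YAML config. A minimal standalone sketch of that precedence, using a toy config object rather than axolotl's real config class:

from dataclasses import dataclass

@dataclass
class ToyCfg:            # illustrative stand-in, not axolotl's config class
    fp8: bool = False
    fp16: bool = False
    bf16: bool = False

def precision_kwargs(cfg: ToyCfg) -> dict:
    # Same precedence as the diff above: fp16 only when bf16 is off,
    # and fp8 forces both HF Trainer flags off.
    kwargs = {"fp16": (cfg.fp16 and not cfg.bf16) or False}
    if cfg.fp8:
        kwargs["fp16"] = False
        kwargs["bf16"] = False
    return kwargs

print(precision_kwargs(ToyCfg(fp8=True, bf16=True)))   # {'fp16': False, 'bf16': False}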
4 changes: 3 additions & 1 deletion src/axolotl/utils/config.py
@@ -70,7 +70,9 @@ def normalize_config(cfg):
     else:
         torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
 
-    if cfg.bf16 or cfg.bfloat16:
+    if cfg.fp8:
+        cfg.torch_dtype = torch.bfloat16
+    elif cfg.bf16 or cfg.bfloat16:
         cfg.torch_dtype = torch.bfloat16
     elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
         cfg.torch_dtype = torch.float16
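normalize_config folds fp8 into the dtype selection so that model weights are loaded in bfloat16 when FP8 mixed precision is requested, which is consistent with FP8 recipes that keep weights in higher precision and cast only the matmuls. A sketch of the resulting precedence; the float32 fallback is assumed from the surrounding function and is not part of this hunk:

import torch

def pick_torch_dtype(fp8=False, bf16=False, load_in_8bit=False, fp16=False):
    # Mirrors the precedence in normalize_config after this commit.
    if fp8:
        return torch.bfloat16   # FP8 training keeps model weights in bf16
    if bf16:
        return torch.bfloat16
    if load_in_8bit or fp16:
        return torch.float16
    return torch.float32        # assumed default branch, not shown in this hunk

assert pick_torch_dtype(fp8=True) is torch.bfloat16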
2 changes: 2 additions & 0 deletions src/axolotl/utils/trainer.py
@@ -268,6 +268,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         setup_fsdp_envs(cfg)
     elif cfg.deepspeed:
         os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+    if cfg.fp8:
+        os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
 
     trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
     trainer_builder.train_dataset = train_dataset
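Setting ACCELERATE_MIXED_PRECISION=fp8 in the environment is equivalent to passing mixed_precision="fp8" to Accelerator, and it has to happen before the Trainer constructs its internal Accelerator, which is presumably why the commit does it in setup_trainer. A minimal sketch, assuming an accelerate install with an FP8 backend available (TransformerEngine on H100-class GPUs, or MS-AMP); without one, requesting fp8 is expected to fail:

import os

# The variable must be set before the first Accelerator is created;
# axolotl sets it in setup_trainer, ahead of building the HF Trainer.
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"

from accelerate import Accelerator

accelerator = Accelerator()         # reads the mixed-precision setting from the environment
print(accelerator.mixed_precision)  # expected: "fp8" on supported hardware/backends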
