Skip to content

Commit

Permalink
use updated accelerate and fix deepspeed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 5, 2024
1 parent b614fb2 commit 7a89595
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 8 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ peft==0.12.0
transformers==4.44.2
tokenizers>=0.19.1
bitsandbytes==0.43.3
accelerate==0.34.0
accelerate==0.34.2
datasets==2.20.0
deepspeed==0.14.4
pydantic==2.6.3
Expand Down
6 changes: 1 addition & 5 deletions src/axolotl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,7 @@ def terminate_handler(_, __, model_weakref):
if trainer.is_fsdp_enabled:
if cfg.fsdp_final_state_dict_type:
state_dict_type = cfg.fsdp_final_state_dict_type
try:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
except TypeError:
trainer.accelerator.state.fsdp_plugin.state_dict_type = state_dict_type
trainer.accelerator.state.fsdp_plugin.set_state_dict_type()
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")

if cfg.relora_steps:
Expand Down
6 changes: 4 additions & 2 deletions tests/e2e/multigpu/test_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"

AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent


class TestMultiGPULlama(unittest.TestCase):
"""
Expand Down Expand Up @@ -372,7 +374,7 @@ def test_ds_zero3_packed(self, temp_dir):
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": "deepspeed_configs/zero3_bf16.yaml",
"deepspeed": AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.yaml",
}
)

Expand Down Expand Up @@ -431,7 +433,7 @@ def test_ds_zero3_qlora_packed(self, temp_dir):
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": "deepspeed_configs/zero3_bf16.yaml",
"deepspeed": AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.yaml",
}
)

Expand Down

0 comments on commit 7a89595

Please sign in to comment.