From c60e8bde17ced869ead85d3c81d4d24e28342a6e Mon Sep 17 00:00:00 2001
From: shw
Date: Tue, 24 Sep 2024 12:35:52 +0800
Subject: [PATCH 1/2] support acc fsdp optim_state_dict

---
 src/transformers/trainer.py | 34 ++++++++++++++++++----------------
 1 file changed, 18 insertions(+), 16 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 257cb3dceb31ec..a33d07c9290b8f 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2959,16 +2959,14 @@ def _save_optimizer_and_scheduler(self, output_dir):
         if is_torch_xla_available():
             xm.rendezvous("saving_optimizer_states")
             if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled:
-                optm = {
-                    "optimizer": self.optimizer.state_dict(),
-                    "shard_metadata": self.model.get_shard_metadata(),
-                }
+                from torchacc.dist.fsdp import FullyShardedDataParallel as ACC_FSDP
+
+                optm = ACC_FSDP.full_optim_state_dict(self.model, self.optimizer)
                 xm.save(
                     optm,
                     os.path.join(
-                        output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
-                    ),
-                    master_only=False,
+                        output_dir, f"{OPTIMIZER_NAME}"
+                    )
                 )
             else:
                 xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
@@ -3050,23 +3048,27 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             )
         )
         checkpoint_file_exists = (
-            glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}"))
+            glob.glob(os.path.join(checkpoint, f"{OPTIMIZER_NAME}"))
             if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled
             else checkpoint_file_exists
         )
+
         if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
             if is_torch_xla_available():
+                from torchacc.dist.fsdp import FullyShardedDataParallel as ACC_FSDP
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
                 if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled:
-                    optimizer_state = torch.load(
-                        os.path.join(
-                            checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
-                        ),
-                        map_location="cpu",
-                    )
-                    # We only need `optimizer` when resuming from checkpoint
-                    optimizer_state = optimizer_state["optimizer"]
+                    optimizer_state = None
+                    if self.args.process_index == 0:
+                        optimizer_state = torch.load(
+                            os.path.join(
+                                checkpoint, f"{OPTIMIZER_NAME}"
+                            ),
+                            map_location="cpu",
+                        )
+
+                    optimizer_state = ACC_FSDP.load_optim_state_dict(self.model, optimizer_state, self.optimizer)
                 else:
                     optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                 with warnings.catch_warnings(record=True) as caught_warnings:

From 1b2922b61a10ed1c05abc073269993d99f110254 Mon Sep 17 00:00:00 2001
From: shw
Date: Tue, 24 Sep 2024 14:06:22 +0800
Subject: [PATCH 2/2] change ACC_FSDP to FSDP

---
 src/transformers/trainer.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index a33d07c9290b8f..2a4c6670910b31 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2959,9 +2959,9 @@ def _save_optimizer_and_scheduler(self, output_dir):
         if is_torch_xla_available():
             xm.rendezvous("saving_optimizer_states")
             if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled:
-                from torchacc.dist.fsdp import FullyShardedDataParallel as ACC_FSDP
+                from torchacc.dist.fsdp import FullyShardedDataParallel as FSDP

-                optm = ACC_FSDP.full_optim_state_dict(self.model, self.optimizer)
+                optm = FSDP.full_optim_state_dict(self.model, self.optimizer)
                 xm.save(
                     optm,
                     os.path.join(
@@ -3056,7 +3056,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
             if is_torch_xla_available():
-                from torchacc.dist.fsdp import FullyShardedDataParallel as ACC_FSDP
+                from torchacc.dist.fsdp import FullyShardedDataParallel as FSDP
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
                 if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled:
                     optimizer_state = None
@@ -3068,7 +3068,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                             map_location="cpu",
                         )

-                    optimizer_state = ACC_FSDP.load_optim_state_dict(self.model, optimizer_state, self.optimizer)
+                    optimizer_state = FSDP.load_optim_state_dict(self.model, optimizer_state, self.optimizer)
                 else:
                     optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                 with warnings.catch_warnings(record=True) as caught_warnings:
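
Note for reviewers: below is a minimal standalone sketch (not part of the patch) of the save/load flow the two patches above implement, assuming a model already wrapped by torchacc's FullyShardedDataParallel and a torch_xla runtime. The helper names save_fsdp_optimizer / load_fsdp_optimizer and the hard-coded OPTIMIZER_NAME value are illustrative only; the FSDP.full_optim_state_dict / FSDP.load_optim_state_dict calls and the rank-0-only torch.load mirror what the patch does inside Trainer.

import os

import torch
import torch_xla.core.xla_model as xm
from torchacc.dist.fsdp import FullyShardedDataParallel as FSDP

OPTIMIZER_NAME = "optimizer.pt"  # illustrative; the Trainer uses its own OPTIMIZER_NAME constant


def save_fsdp_optimizer(model, optimizer, output_dir):
    # Consolidate the sharded optimizer state into a single full state dict,
    # then save it with xm.save, which by default writes from the master
    # ordinal only, so one rank-agnostic file replaces the old per-rank files.
    full_osd = FSDP.full_optim_state_dict(model, optimizer)
    xm.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))


def load_fsdp_optimizer(model, optimizer, checkpoint_dir, process_index):
    # Only rank 0 reads the consolidated file from disk; the other ranks pass None.
    full_osd = None
    if process_index == 0:
        full_osd = torch.load(os.path.join(checkpoint_dir, OPTIMIZER_NAME), map_location="cpu")
    # As used in the patch, load_optim_state_dict returns the state dict for this
    # rank, which the Trainer later feeds to optimizer.load_state_dict.
    return FSDP.load_optim_state_dict(model, full_osd, optimizer)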