diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 257cb3dceb31ec..2a4c6670910b31 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2959,16 +2959,14 @@ def _save_optimizer_and_scheduler(self, output_dir):
         if is_torch_xla_available():
             xm.rendezvous("saving_optimizer_states")
             if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled:
-                optm = {
-                    "optimizer": self.optimizer.state_dict(),
-                    "shard_metadata": self.model.get_shard_metadata(),
-                }
+                from torchacc.dist.fsdp import FullyShardedDataParallel as FSDP
+
+                optm = FSDP.full_optim_state_dict(self.model, self.optimizer)
                 xm.save(
                     optm,
                     os.path.join(
-                        output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
-                    ),
-                    master_only=False,
+                        output_dir, f"{OPTIMIZER_NAME}"
+                    )
                 )
             else:
                 xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
@@ -3050,23 +3048,27 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             )
         )
         checkpoint_file_exists = (
-            glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}"))
+            glob.glob(os.path.join(checkpoint, f"{OPTIMIZER_NAME}"))
             if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled
             else checkpoint_file_exists
         )
+
         if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
             if is_torch_xla_available():
+                from torchacc.dist.fsdp import FullyShardedDataParallel as FSDP
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
                 if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled:
-                    optimizer_state = torch.load(
-                        os.path.join(
-                            checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
-                        ),
-                        map_location="cpu",
-                    )
-                    # We only need `optimizer` when resuming from checkpoint
-                    optimizer_state = optimizer_state["optimizer"]
+                    optimizer_state = None
+                    if self.args.process_index == 0:
+                        optimizer_state = torch.load(
+                            os.path.join(
+                                checkpoint, f"{OPTIMIZER_NAME}"
+                            ),
+                            map_location="cpu",
+                        )
+
+                    optimizer_state = FSDP.load_optim_state_dict(self.model, optimizer_state, self.optimizer)
                 else:
                     optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                 with warnings.catch_warnings(record=True) as caught_warnings:
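
Note (not part of the patch): the new save path consolidates the sharded optimizer state into a single full state dict via FSDP.full_optim_state_dict and writes one file instead of one file per rank (xm.save without master_only=False writes only from the master process), and the new load path reads that file on process 0 only and lets FSDP.load_optim_state_dict re-shard it onto every rank's optimizer before the usual load_state_dict call later in trainer.py. Below is a minimal standalone sketch of that round trip; it assumes a model already wrapped in torchacc's FullyShardedDataParallel, uses plain torch.save instead of xm.save to stay torch_xla-free, and the names save_dir, "optimizer.pt", and process_index are placeholders, not trainer.py's own.

# Sketch only: illustrates the consolidated-save / rank-0-load pattern adopted above.
# `model` is assumed to be wrapped in torchacc's FullyShardedDataParallel; `save_dir`
# and `process_index` are placeholders, not names from trainer.py.
import os

import torch
from torchacc.dist.fsdp import FullyShardedDataParallel as FSDP


def save_full_optimizer(model, optimizer, save_dir, process_index):
    # Gather the per-rank optimizer shards into one full state dict.
    full_osd = FSDP.full_optim_state_dict(model, optimizer)
    if process_index == 0:
        # Only the master process writes the consolidated file.
        torch.save(full_osd, os.path.join(save_dir, "optimizer.pt"))


def load_full_optimizer(model, optimizer, save_dir, process_index):
    # Only rank 0 reads the consolidated file; the other ranks pass None and
    # receive their shard back from load_optim_state_dict.
    full_osd = None
    if process_index == 0:
        full_osd = torch.load(os.path.join(save_dir, "optimizer.pt"), map_location="cpu")
    sharded_osd = FSDP.load_optim_state_dict(model, full_osd, optimizer)
    optimizer.load_state_dict(sharded_osd)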