diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 8c9b99191f632f..15af9cd8d5a12e 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3407,8 +3407,6 @@ def _save_tpu(self, output_dir: Optional[str] = None):
         logger.info(f"Saving model checkpoint to {output_dir}")
         model = self.model
         xm.mark_step()
-        if self.args.save_safetensors:
-            model.to("cpu")
 
         if xm.is_master_ordinal():
             os.makedirs(output_dir, exist_ok=True)
@@ -3423,13 +3421,13 @@ def _save_tpu(self, output_dir: Optional[str] = None):
                 self.accelerator.unwrap_model(model).save_pretrained(
                     output_dir,
                     is_main_process=self.args.should_save,
-                    state_dict=model.state_dict(),
+                    state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
                     save_function=xm.save,
                     safe_serialization=self.args.save_safetensors,
                 )
             else:
                 logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
-                state_dict = model.state_dict()
+                state_dict = xm._maybe_convert_to_cpu(model.state_dict())
                 xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
             model.save_pretrained(
@@ -3437,15 +3435,11 @@ def _save_tpu(self, output_dir: Optional[str] = None):
                 is_main_process=self.args.should_save,
                 save_function=xm.save,
                 safe_serialization=self.args.save_safetensors,
+                state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
             )
         if self.tokenizer is not None and self.args.should_save:
             self.tokenizer.save_pretrained(output_dir)
 
-        # We moved the model from TPU -> CPU for saving the weights.
-        # Now we should move it back to subsequent compute still works.
-        if self.args.save_safetensors:
-            model.to(self.args.device)
-
     def _save(self, output_dir: Optional[str] = None, state_dict=None):
         # If we are executing this function, we are the process zero, so we don't check for that.
         output_dir = output_dir if output_dir is not None else self.args.output_dir
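
For reference, here is a minimal sketch of the save pattern the patch switches to, outside of `Trainer`: instead of moving the whole model to CPU before saving and back to the device afterwards, only a CPU copy of the state dict is handed to `save_pretrained`. The `save_tpu_checkpoint` helper below is hypothetical (not part of the patch), and `xm._maybe_convert_to_cpu` is a private torch_xla helper, so treat this as an illustration of the idea rather than a supported API.

```python
import torch_xla.core.xla_model as xm


def save_tpu_checkpoint(model, output_dir, save_safetensors=True):
    """Hypothetical helper: checkpoint a PreTrainedModel that lives on an XLA device."""
    xm.mark_step()  # flush pending XLA ops so the weight values are materialized
    model.save_pretrained(
        output_dir,
        # Only the checkpoint copy is moved to CPU; the model itself stays on
        # the TPU, so no model.to("cpu") / model.to(device) round trip is needed.
        state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
        save_function=xm.save,
        safe_serialization=save_safetensors,
    )
```

This mirrors the diff above: the TPU-to-CPU transfer is confined to the tensors being serialized, which avoids the extra device moves (and the need to restore the model afterwards) that the removed `model.to("cpu")` path required.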