Skip to content

Commit b9898b1

Browse files
authored
【BUG】fix save torch_dtype in trainer (PaddlePaddle#2609)
1 parent 53230c0 commit b9898b1

File tree

4 files changed

+6
-6
lines changed

4 files changed

+6
-6
lines changed

paddleformers/trainer/unified_checkpoint/load_save_single_card.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def save_single_card_checkpoint(model_to_save, output_dir, save_to_hf=False):
101101
logger.warning("Asynchronous saving is not supported for single card environment currently.")
102102
save_file_sync(state_dict, path=os.path.join(output_dir, weight_filename), save_to_hf=save_to_hf)
103103

104-
save_model_config(model_to_save, output_dir)
104+
save_model_config(model_to_save, output_dir, save_to_hf)
105105

106106

107107
def save_single_card_optimizer(model, optimizer, output_dir):

paddleformers/trainer/unified_checkpoint/unified_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None,
166166
json.dump(sharded_index, f, indent=4)
167167

168168
if self.args.should_save:
169-
save_model_config(model_to_save, save_directory)
169+
save_model_config(model_to_save, save_directory, save_to_hf)
170170

171171
empty_device_cache()
172172

paddleformers/trainer/unified_checkpoint/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,7 @@ def is_sharding_split_param_mode(args):
739739
)
740740

741741

742-
def save_model_config(model_to_save, save_directory):
742+
def save_model_config(model_to_save, save_directory, save_to_hf=False):
743743
"""
744744
Save model config.
745745
"""
@@ -770,7 +770,7 @@ def save_config(model_to_save):
770770
else:
771771
config_to_save.architectures = [clean_model_class_name(model_to_save.__class__.__name__)]
772772

773-
config_to_save.save_pretrained(save_directory)
773+
config_to_save.save_pretrained(save_directory, save_to_hf=save_to_hf)
774774
# save generation config
775775
if model_to_save.can_generate():
776776
model_to_save.generation_config.save_pretrained(save_directory)

paddleformers/transformers/model_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1604,7 +1604,7 @@ def constructed_from_pretrained_config(cls, init_func=None) -> bool:
16041604
"""
16051605
return cls.config_class is not None and issubclass(cls.config_class, PretrainedConfig)
16061606

1607-
def save_model_config(self, save_dir: str):
1607+
def save_model_config(self, save_dir: str, **kwargs):
16081608
"""
16091609
Deprecated, please use `.config.save_pretrained()` instead.
16101610
Saves model configuration to a file named "config.json" under `save_dir`.
@@ -1613,7 +1613,7 @@ def save_model_config(self, save_dir: str):
16131613
save_dir (str): Directory to save model_config file into.
16141614
"""
16151615
logger.warning("The `save_model_config` is deprecated! Please use `.config.save_pretrained()` instead.")
1616-
self.config.save_pretrained(save_dir)
1616+
self.config.save_pretrained(save_dir, **kwargs)
16171617

16181618
def save_to_hf_hub(
16191619
self,

0 commit comments

Comments
 (0)