From 60861fe1fd743ee5ff2bf49beb287b1cea438ccd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=A1=B0=EC=A4=80=EB=9E=98?= Date: Fri, 7 Jun 2024 23:43:34 +0900 Subject: [PATCH] Implement JSON dump conversion for torch_dtype in TrainingArguments (#31224) * Implement JSON dump conversion for torch_dtype in TrainingArguments * Add unit test for converting torch_dtype in TrainingArguments to JSON * move unit test for converting torch_dtype into TrainerIntegrationTest class * reformating using ruff * convert dict_torch_dtype_to_str to private method _dict_torch_dtype_to_str --------- Co-authored-by: jun.4 --- src/transformers/training_args.py | 14 ++++++++++++++ tests/trainer/test_trainer.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index efb05682aae1..80b0740d20a3 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -2370,6 +2370,18 @@ def get_warmup_steps(self, num_training_steps: int): ) return warmup_steps + def _dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None: + """ + Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None, + converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"* + string, which can then be stored in the json format. + """ + if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str): + d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1] + for value in d.values(): + if isinstance(value, dict): + self._dict_torch_dtype_to_str(value) + def to_dict(self): """ Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates @@ -2388,6 +2400,8 @@ def to_dict(self): # Handle the accelerator_config if passed if is_accelerate_available() and isinstance(v, AcceleratorConfig): d[k] = v.to_dict() + self._dict_torch_dtype_to_str(d) + return d def to_json_string(self): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 4d3fc5734005..af456a9bdad0 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -3445,6 +3445,35 @@ class CustomTrainingArguments(TrainingArguments): ) self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception)) + def test_torch_dtype_to_json(self): + @dataclasses.dataclass + class TorchDtypeTrainingArguments(TrainingArguments): + torch_dtype: torch.dtype = dataclasses.field( + default=torch.float32, + ) + + for dtype in [ + "float32", + "float64", + "complex64", + "complex128", + "float16", + "bfloat16", + "uint8", + "int8", + "int16", + "int32", + "int64", + "bool", + ]: + torch_dtype = getattr(torch, dtype) + with tempfile.TemporaryDirectory() as tmp_dir: + args = TorchDtypeTrainingArguments(output_dir=tmp_dir, torch_dtype=torch_dtype) + + args_dict = args.to_dict() + self.assertIn("torch_dtype", args_dict) + self.assertEqual(args_dict["torch_dtype"], dtype) + @require_torch @is_staging_test