Skip to content

Commit

Permalink
allow to set mxied_precision in accelerate config when using a given …
Browse files Browse the repository at this point in the history
…deepspeed config
  • Loading branch information
XiaobingSuper committed Feb 12, 2025
1 parent 526925b commit a43dec6
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 7 deletions.
8 changes: 2 additions & 6 deletions src/accelerate/commands/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,6 @@ def _validate_launch_command(args):

defaults = None
warned = []
mp_from_config_flag = False
# Get the default from the config file.
if args.config_file is not None or os.path.isfile(default_config_file) and not args.cpu:
defaults = load_config_from_file(args.config_file)
Expand Down Expand Up @@ -1081,7 +1080,6 @@ def _validate_launch_command(args):
args.mixed_precision = "no"
else:
args.mixed_precision = defaults.mixed_precision
mp_from_config_flag = True
else:
if args.use_cpu or (args.use_xpu and torch.xpu.is_available()):
native_amp = True
Expand Down Expand Up @@ -1168,16 +1166,14 @@ def _validate_launch_command(args):
"\nTo avoid this warning pass in values for each of the problematic parameters or run `accelerate config`."
)
logger.warning(message)
return args, defaults, mp_from_config_flag
return args, defaults


def launch_command(args):
args, defaults, mp_from_config_flag = _validate_launch_command(args)
args, defaults = _validate_launch_command(args)
# Use the proper launcher
if args.use_deepspeed and not args.cpu:
args.deepspeed_fields_from_accelerate_config = list(defaults.deepspeed_config.keys()) if defaults else []
if mp_from_config_flag:
args.deepspeed_fields_from_accelerate_config.append("mixed_precision")
args.deepspeed_fields_from_accelerate_config = ",".join(args.deepspeed_fields_from_accelerate_config)
deepspeed_launcher(args)
elif args.use_fsdp and not args.cpu:
Expand Down
1 change: 0 additions & 1 deletion src/accelerate/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,7 +1335,6 @@ def _deepspeed_config_checks(self):
"ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_NVME_PATH",
"ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_NVME_PATH",
"ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL",
"ACCELERATE_MIXED_PRECISION",
]
env_variable_names_to_ignore = [
name.replace("ACCELERATE_", "").replace("DEEPSPEED_", "").lower() for name in env_variable_names_to_ignore
Expand Down

0 comments on commit a43dec6

Please sign in to comment.