Enable torch.autocast with ZeRO #6993

Draft: wants to merge 6 commits into master
20 changes: 20 additions & 0 deletions deepspeed/runtime/config.py
@@ -156,6 +156,24 @@ def get_amp_params(param_dict):
        return False


def get_torch_autocast_enabled(param_dict):
    if TORCH_AUTOCAST in param_dict.keys():
        return get_scalar_param(param_dict[TORCH_AUTOCAST], TORCH_AUTOCAST_ENABLED, TORCH_AUTOCAST_ENABLED_DEFAULT)
    else:
        return False


def get_torch_autocast_dtype(param_dict):
    if TORCH_AUTOCAST in param_dict:
        if TORCH_AUTOCAST_DTYPE in param_dict[TORCH_AUTOCAST]:
            try:
                return DtypeEnum(param_dict[TORCH_AUTOCAST][TORCH_AUTOCAST_DTYPE]).value
            except KeyError:
                raise ValueError(
                    f"Invalid dtype for torch autocast: {param_dict[TORCH_AUTOCAST][TORCH_AUTOCAST_DTYPE]}")
    return None


def get_fp16_enabled(param_dict):
    if FP16 in param_dict.keys():
        return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT)
@@ -835,6 +853,8 @@ def _initialize_params(self, param_dict):
        self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(param_dict)
        self.amp_enabled = get_amp_enabled(param_dict)
        self.amp_params = get_amp_params(param_dict)
        self.torch_autocast_enabled = get_torch_autocast_enabled(param_dict)
        self.torch_autocast_dtype = get_torch_autocast_dtype(param_dict)
        self.loss_scale = get_loss_scale(param_dict)
        self.initial_dynamic_scale = get_initial_dynamic_scale(param_dict)
        self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict)
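For context, here is a minimal sketch of what these getters consume and return. The "torch_autocast" section and its keys are the ones defined by the constants below; the dtype table is a hand-rolled stand-in for DeepSpeed's DtypeEnum resolution and is only illustrative.

import torch

# Hypothetical user config fragment following the new "torch_autocast" section.
param_dict = {
    "torch_autocast": {
        "enabled": True,
        "dtype": "bfloat16",
    }
}

# Mirrors get_torch_autocast_enabled(): a missing section falls back to False.
enabled = param_dict.get("torch_autocast", {}).get("enabled", False)

# Mirrors get_torch_autocast_dtype(): the dtype string is resolved to a torch.dtype
# (DeepSpeed does this via DtypeEnum); an unset dtype yields None.
_DTYPES = {"float32": torch.float32, "fp32": torch.float32,
           "float16": torch.float16, "fp16": torch.float16,
           "bfloat16": torch.bfloat16, "bf16": torch.bfloat16}
dtype = _DTYPES.get(param_dict["torch_autocast"].get("dtype"))

assert enabled and dtype is torch.bfloat16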
17 changes: 17 additions & 0 deletions deepspeed/runtime/constants.py
@@ -202,6 +202,23 @@
AMP_ENABLED = "enabled"
AMP_ENABLED_DEFAULT = False

#########################################
# Torch AMP support
#########################################
TORCH_AUTOCAST_FORMAT = '''
PyTorch autocast config should be of the format:
"torch_autocast": {
"enabled": true,
"dtype": "bfloat16",
}
'''
TORCH_AUTOCAST = "torch_autocast"

TORCH_AUTOCAST_ENABLED = "enabled"
TORCH_AUTOCAST_ENABLED_DEFAULT = False
TORCH_AUTOCAST_DTYPE = "dtype"
TORCH_AUTOCAST_DTYPE_DEFAULT = None

#########################################
# Gradient clipping
#########################################
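Putting the new constants together, a DeepSpeed config that enables torch autocast under ZeRO could look like the sketch below. The "torch_autocast" block follows TORCH_AUTOCAST_FORMAT above; the other keys and the values chosen here are illustrative assumptions, not part of this PR.

import deepspeed  # assumes a DeepSpeed build that includes this branch

ds_config = {
    "train_batch_size": 8,
    "zero_optimization": {"stage": 2},
    # fp16/bf16 sections must stay disabled: _validate_auto_cast_settings()
    # in torch_autocast.py rejects combining them with torch autocast.
    "torch_autocast": {
        "enabled": True,
        "dtype": "bfloat16"
    }
}

# engine, optimizer, _, _ = deepspeed.initialize(model=model,
#                                                model_parameters=model.parameters(),
#                                                config=ds_config)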
15 changes: 14 additions & 1 deletion deepspeed/runtime/engine.py
@@ -90,6 +90,7 @@

from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from deepspeed.runtime.torch_autocast import init_autocast_params

from .pipe.module import PipelineModule
from .utils import get_ma_status
@@ -311,6 +312,9 @@ def __init__(self,
        if not isinstance(model_parameters, list):
            model_parameters = list(model_parameters)

        if self.torch_autocast_enabled():
            init_autocast_params(self, self.torch_autocast_dtype())

        if has_optimizer:
            self._configure_optimizer(optimizer, model_parameters)
            self._configure_lr_scheduler()
@@ -850,6 +854,12 @@ def amp_enabled(self):
    def amp_params(self):
        return self._config.amp_params

    def torch_autocast_enabled(self):
        return self._config.torch_autocast_enabled

    def torch_autocast_dtype(self):
        return self._config.torch_autocast_dtype

    def fp16_auto_cast(self):
        return self._config.fp16_auto_cast

@@ -1909,7 +1919,10 @@ def forward(self, *inputs, **kwargs):
        if self.fp16_auto_cast():
            inputs = self._cast_inputs_half(inputs)

-       loss = self.module(*inputs, **kwargs)
+       with torch.autocast(device_type=get_accelerator().device_name(),
+                           dtype=self.torch_autocast_dtype(),
+                           enabled=self.torch_autocast_enabled()):
+           loss = self.module(*inputs, **kwargs)

        if self.zero_optimization_partition_weights():
            # Disable automated discovery of external parameters
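As a standalone illustration of what the wrapped forward now does, the snippet below (independent of DeepSpeed; the CUDA device and bfloat16 dtype are assumptions) shows that autocast-eligible ops run in the lower precision while the fp32 parameters, which ZeRO partitions and communicates, are left untouched.

import torch

lin = torch.nn.Linear(16, 16).cuda()   # parameters are created in torch.float32
x = torch.randn(4, 16, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
    y = lin(x)

print(lin.weight.dtype)  # torch.float32: master weights are not downcast
print(y.dtype)           # torch.bfloat16: the matmul ran under autocast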
53 changes: 53 additions & 0 deletions deepspeed/runtime/torch_autocast.py
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Iterable, Set

import torch

LOWER_PRECISION_SAFE_MODULES = [
    torch.nn.Linear,
    torch.nn.Conv1d,
    torch.nn.Conv2d,
    torch.nn.Conv3d,
]

TORCH_AUTOCAST_INITIALIZED = False


def _validate_auto_cast_settings(engine):

    assert not engine.fp16_enabled(), "Cannot enable both torch autocast and fp16"
    assert not engine.bfloat16_enabled(), "Cannot enable both torch autocast and bfloat16"

    assert all(p.dtype == torch.float32
               for p in engine.parameters()), "All parameters must be float32 for torch autocast"
    assert engine.communication_data_type == torch.float32, "Communication data type must be float32 for torch autocast"


def init_autocast_params(engine, dtype: torch.dtype) -> None:

    _validate_auto_cast_settings(engine)
    model = engine.module

    for module in model.modules():
        if module.__class__ in LOWER_PRECISION_SAFE_MODULES:
            for p in module.parameters(recurse=False):
                p.autocast_dtype = dtype

    global TORCH_AUTOCAST_INITIALIZED
    TORCH_AUTOCAST_INITIALIZED = True


def is_autocast_initialized() -> bool:
    return TORCH_AUTOCAST_INITIALIZED


def get_autocast_dtype(param: torch.nn.Parameter) -> torch.dtype:
    return param.autocast_dtype if hasattr(param, "autocast_dtype") else param.dtype


def get_all_autocast_dtypes(params: Iterable) -> Set[torch.dtype]:
    return {get_autocast_dtype(p) for p in params}
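A short sketch of how these helpers fit together (it assumes this branch is installed so deepspeed.runtime.torch_autocast is importable, and it tags parameters by hand because init_autocast_params() expects a full DeepSpeed engine). The resulting per-parameter dtypes are what a ZeRO optimizer could use to group gradients by communication dtype.

import torch
from deepspeed.runtime.torch_autocast import (LOWER_PRECISION_SAFE_MODULES,
                                              get_autocast_dtype, get_all_autocast_dtypes)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))

# Same tagging loop as init_autocast_params(), applied to a bare module.
for module in model.modules():
    if module.__class__ in LOWER_PRECISION_SAFE_MODULES:
        for p in module.parameters(recurse=False):
            p.autocast_dtype = torch.bfloat16

print(get_autocast_dtype(model[0].weight))          # torch.bfloat16 (Linear is autocast-safe)
print(get_autocast_dtype(model[1].weight))          # torch.float32 (LayerNorm keeps its own dtype)
print(get_all_autocast_dtypes(model.parameters()))  # {torch.bfloat16, torch.float32} (order may vary)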