Lay out params in a contiguous buffer using a new ParamAndGradBuffer
- Re-map parameters only when using the distributed optimizer
- Remove unnecessary param copying logic after all-gather
- Unmap weight_tensor attributes if they exist to reduce memory footprint
deepakn94 committed Mar 16, 2024
1 parent 73ce965 commit 293e104
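
For orientation: the idea behind the new ParamAndGradBuffer is that each (param dtype, grad dtype) group of parameters gets one flat allocation, and every parameter's data and main_grad become views into it, so the distributed optimizer can all-gather or reduce-scatter whole buffers instead of copying parameters individually after the all-gather. The sketch below only illustrates that re-mapping under assumed names (FlatParamAndGradBuffer, param_data, grad_data); it is not the Megatron-LM implementation, and it omits the bucketing the real buffer uses to overlap communication with the backward pass.

# Illustrative sketch only; names and structure are hypothetical, not Megatron's API.
import torch

class FlatParamAndGradBuffer:
    def __init__(self, params, param_dtype, grad_dtype):
        device = params[0].device
        numel = sum(p.numel() for p in params)
        # One contiguous allocation for parameter values and one for (main) gradients.
        self.param_data = torch.empty(numel, dtype=param_dtype, device=device)
        self.grad_data = torch.zeros(numel, dtype=grad_dtype, device=device)

        offset = 0
        for p in params:
            n = p.numel()
            # Preserve current values, then re-map param.data to a view of the flat buffer.
            self.param_data[offset:offset + n].copy_(p.data.reshape(-1))
            p.data = self.param_data[offset:offset + n].view(p.shape)
            # Gradients are accumulated into a matching view (e.g. an fp32 main_grad).
            p.main_grad = self.grad_data[offset:offset + n].view(p.shape)
            offset += n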
Showing 5 changed files with 175 additions and 165 deletions.
megatron/core/distributed/distributed_data_parallel.py (83 changes: 49 additions & 34 deletions)
@@ -8,7 +8,7 @@
from .. import parallel_state
from ..transformer.module import MegatronModule
from ..transformer.transformer_config import TransformerConfig
-from .grad_buffer import GradBuffer
+from .grad_buffer import ParamAndGradBuffer


class DistributedDataParallel(MegatronModule):
@@ -73,7 +73,7 @@ def __init__(
self.bucket_size = bucket_size

self.module = module
-        self.param_to_grad_buffer = {}
+        self.param_to_buffer = {}

# Group parameters by their gradient type.
param_to_name = {}
@@ -91,28 +91,30 @@ def __init__(
else:
expert_parallel_params.append(param)

-        def allocate_grad_buffers_for_parameters(
+        def allocate_buffers_for_parameters(
input_params, data_parallel_group, gradient_scaling_factor=1.0,
):
-            grad_dtype_to_params = {}
+            param_and_grad_dtype_to_params = {}

# Group parameters by their gradient type.
for param in input_params:
if not param.requires_grad:
continue

-                dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype
+                param_dtype = param.dtype
+                grad_dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype

-                params = grad_dtype_to_params.get(dtype, [])
+                params = param_and_grad_dtype_to_params.get((param_dtype, grad_dtype), [])
params.append(param)
-                grad_dtype_to_params[dtype] = params
+                param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params

# Allocate the grad buffers and map the grads.
-            grad_buffers = []
-            for dtype, params in grad_dtype_to_params.items():
-                grad_buffers.append(
-                    GradBuffer(
-                        dtype,
+            buffers = []
+            for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items():
+                buffers.append(
+                    ParamAndGradBuffer(
+                        param_dtype,
+                        grad_dtype,
params,
data_parallel_group,
bucket_size,
@@ -124,26 +126,39 @@ def allocate_grad_buffers_for_parameters(
)
)
for param in params:
-                    self.param_to_grad_buffer[param] = grad_buffers[-1]
+                    self.param_to_buffer[param] = buffers[-1]

-            return grad_buffers
+            return buffers

data_parallel_world_size = torch.distributed.get_world_size(data_parallel_group)

-        # Allocate the grad buffers for dense params' grads.
-        self.grad_buffers = allocate_grad_buffers_for_parameters(
+        # Allocate the param+grad buffers for dense params' grads.
+        self.buffers = allocate_buffers_for_parameters(
dense_params,
data_parallel_group,
gradient_scaling_factor=1.0 / data_parallel_world_size,
)

-        # Allocate separate grad buffers for expert parallel params' grads.
-        self.expert_parallel_grad_buffers = allocate_grad_buffers_for_parameters(
+        # Allocate separate param+grad buffers for expert parallel params' grads.
+        self.expert_parallel_buffers = allocate_buffers_for_parameters(
expert_parallel_params,
expert_data_parallel_group,
gradient_scaling_factor=1.0 / data_parallel_world_size,
)

+        # Delete references to weight_tensor if they exist since we don't want two parameter copies
+        # if we re-mapped parameters (which happens when we use the distributed optimizer).
+        # This is a temporary workaround around a TE bug that is fixed with
+        # https://github.com/NVIDIA/TransformerEngine/pull/719.
+        if self.use_distributed_optimizer:
+
+            @torch.no_grad()
+            def unmap_weight_tensor(m):
+                if hasattr(m, 'weight_tensor'):
+                    m.weight_tensor = None
+
+            self.module.apply(unmap_weight_tensor)
+
# Register backward hook.
# Accumulation function for the gradients need to be stored so they
# don't go out of scope.
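
The weight_tensor unmapping added above matters because re-mapping param.data into the shared buffer does not free the original weight storage while another attribute still references it. A minimal standalone illustration of that stale-reference effect follows; this is hypothetical demo code, not Megatron or TransformerEngine code.

# Hypothetical illustration: a second reference keeps the old weight storage alive
# after param.data has been re-mapped into a flat buffer.
import torch

linear = torch.nn.Linear(1024, 1024)
linear.weight_tensor = linear.weight.data            # extra reference, like TE's cached tensor

flat = torch.empty(linear.weight.numel())            # stand-in for the shared param buffer
flat.copy_(linear.weight.data.reshape(-1))
linear.weight.data = flat.view(linear.weight.shape)  # re-map the parameter into the buffer

# Without the line below, the original 1024x1024 storage stays resident via weight_tensor.
linear.weight_tensor = None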
@@ -154,7 +169,7 @@ def allocate_grad_buffers_for_parameters(
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
-                grad_acc.register_hook(self._make_param_hook(param, self.param_to_grad_buffer))
+                grad_acc.register_hook(self._make_param_hook(param, self.param_to_buffer))
self.grad_accs.append(grad_acc)

def forward(self, *inputs, **kwargs):
@@ -164,7 +179,9 @@ def forward(self, *inputs, **kwargs):
return self.module(*inputs, **kwargs)

def _make_param_hook(
-        self, param: torch.nn.Parameter, param_to_grad_buffer: Dict[torch.nn.Parameter, GradBuffer]
+        self,
+        param: torch.nn.Parameter,
+        param_to_buffer: Dict[torch.nn.Parameter, ParamAndGradBuffer],
):
"""
Creates the all-reduce / reduce-scatter hook for backprop.
@@ -183,7 +200,7 @@ def param_hook(*unused):
param.grad = None

if self.overlap_grad_reduce:
-                param_to_grad_buffer[param].register_grad_ready(param)
+                param_to_buffer[param].register_grad_ready(param)

return param_hook

@@ -192,13 +209,13 @@ def no_sync(self):
"""
Context manager that turns off gradient synchronization.
"""
-        for grad_buffer in self.grad_buffers + self.expert_parallel_grad_buffers:
-            grad_buffer.is_last_microbatch = False
+        for buffer in self.buffers + self.expert_parallel_buffers:
+            buffer.is_last_microbatch = False
try:
yield
finally:
-            for grad_buffer in self.grad_buffers + self.expert_parallel_grad_buffers:
-                grad_buffer.is_last_microbatch = True
+            for buffer in self.buffers + self.expert_parallel_buffers:
+                buffer.is_last_microbatch = True

def start_grad_sync(self, *unused):
"""
@@ -209,8 +226,8 @@ def start_grad_sync(self, *unused):
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
-        for grad_buffer in self.grad_buffers + self.expert_parallel_grad_buffers:
-            grad_buffer.start_grad_sync()
+        for buffer in self.buffers + self.expert_parallel_buffers:
+            buffer.start_grad_sync()

def finish_grad_sync(self):
"""
Expand All @@ -221,21 +238,19 @@ def finish_grad_sync(self):
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
-        for grad_buffer in self.grad_buffers + self.expert_parallel_grad_buffers:
-            grad_buffer.finish_grad_sync()
+        for buffer in self.buffers + self.expert_parallel_buffers:
+            buffer.finish_grad_sync()

-    def zero_grad_buffer(self, zero_buffer):
+    def zero_grad_buffer(self):
"""
Zeros out all grad buffers. Needs to be called at the beginning of each
training iteration.
-        When zero_buffer is set to True, the underlying grad buffer is zeroed out.
"""
for param in self.module.parameters():
if param.requires_grad:
param.grad_added_to_main_grad = False
-        for grad_buffer in self.grad_buffers + self.expert_parallel_grad_buffers:
-            grad_buffer.reset(zero_buffer)
+        for buffer in self.buffers + self.expert_parallel_buffers:
+            buffer.reset()

def broadcast_params(self):
"""
(Diffs for the remaining 4 changed files are not shown here.)
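
Downstream of this file, the visible interface changes are the renamed buffer lists (buffers, expert_parallel_buffers) and zero_grad_buffer() dropping its zero_buffer argument. Below is a hedged sketch of a gradient-accumulation step against that interface; model is assumed to be wrapped in this DistributedDataParallel class, while microbatches, loss_fn, and optimizer are placeholders rather than code from this repository.

# Sketch of one training step driving the updated DDP interface shown in the diff above.
model.zero_grad_buffer()                     # no zero_buffer argument anymore

with model.no_sync():                        # suppress grad sync for all but the last microbatch
    for batch in microbatches[:-1]:
        loss_fn(model(batch)).backward()

loss_fn(model(microbatches[-1])).backward()  # last microbatch runs with sync enabled
model.finish_grad_sync()                     # wait for / issue the grad communication calls
optimizer.step()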
