From 293e10419fd1b79c8680a0f4a206fc0a373729b5 Mon Sep 17 00:00:00 2001 From: Deepak Narayanan Date: Wed, 14 Feb 2024 14:14:35 -0800 Subject: [PATCH] Lay out params in a contiguous buffer using a new ParamAndGradBuffer - Re-map parameters only when using the distributed optimizer - Remove unnecessary param copying logic after all-gather - Unmap weight_tensor attributes if they exist to reduce memory footprint --- .../distributed/distributed_data_parallel.py | 83 +++++++----- megatron/core/distributed/grad_buffer.py | 114 ++++++++++++----- megatron/core/optimizer/__init__.py | 20 +-- megatron/core/optimizer/distrib_optimizer.py | 118 +++++------------- megatron/training.py | 5 +- 5 files changed, 175 insertions(+), 165 deletions(-) diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py index d8cc637236..d664c32066 100644 --- a/megatron/core/distributed/distributed_data_parallel.py +++ b/megatron/core/distributed/distributed_data_parallel.py @@ -8,7 +8,7 @@ from .. import parallel_state from ..transformer.module import MegatronModule from ..transformer.transformer_config import TransformerConfig -from .grad_buffer import GradBuffer +from .grad_buffer import ParamAndGradBuffer class DistributedDataParallel(MegatronModule): @@ -73,7 +73,7 @@ def __init__( self.bucket_size = bucket_size self.module = module - self.param_to_grad_buffer = {} + self.param_to_buffer = {} # Group parameters by their gradient type. param_to_name = {} @@ -91,28 +91,30 @@ def __init__( else: expert_parallel_params.append(param) - def allocate_grad_buffers_for_parameters( + def allocate_buffers_for_parameters( input_params, data_parallel_group, gradient_scaling_factor=1.0, ): - grad_dtype_to_params = {} + param_and_grad_dtype_to_params = {} # Group parameters by their gradient type. for param in input_params: if not param.requires_grad: continue - dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype + param_dtype = param.dtype + grad_dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype - params = grad_dtype_to_params.get(dtype, []) + params = param_and_grad_dtype_to_params.get((param_dtype, grad_dtype), []) params.append(param) - grad_dtype_to_params[dtype] = params + param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params # Allocate the grad buffers and map the grads. - grad_buffers = [] - for dtype, params in grad_dtype_to_params.items(): - grad_buffers.append( - GradBuffer( - dtype, + buffers = [] + for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items(): + buffers.append( + ParamAndGradBuffer( + param_dtype, + grad_dtype, params, data_parallel_group, bucket_size, @@ -124,26 +126,39 @@ def allocate_grad_buffers_for_parameters( ) ) for param in params: - self.param_to_grad_buffer[param] = grad_buffers[-1] + self.param_to_buffer[param] = buffers[-1] - return grad_buffers + return buffers data_parallel_world_size = torch.distributed.get_world_size(data_parallel_group) - # Allocate the grad buffers for dense params' grads. - self.grad_buffers = allocate_grad_buffers_for_parameters( + # Allocate the param+grad buffers for dense params' grads. + self.buffers = allocate_buffers_for_parameters( dense_params, data_parallel_group, gradient_scaling_factor=1.0 / data_parallel_world_size, ) - # Allocate separate grad buffers for expert parallel params' grads. 
- self.expert_parallel_grad_buffers = allocate_grad_buffers_for_parameters( + # Allocate separate param+grad buffers for expert parallel params' grads. + self.expert_parallel_buffers = allocate_buffers_for_parameters( expert_parallel_params, expert_data_parallel_group, gradient_scaling_factor=1.0 / data_parallel_world_size, ) + # Delete references to weight_tensor if they exist since we don't want two parameter copies + # if we re-mapped parameters (which happens when we use the distributed optimizer). + # This is a temporary workaround around a TE bug that is fixed with + # https://github.com/NVIDIA/TransformerEngine/pull/719. + if self.use_distributed_optimizer: + + @torch.no_grad() + def unmap_weight_tensor(m): + if hasattr(m, 'weight_tensor'): + m.weight_tensor = None + + self.module.apply(unmap_weight_tensor) + # Register backward hook. # Accumulation function for the gradients need to be stored so they # don't go out of scope. @@ -154,7 +169,7 @@ def allocate_grad_buffers_for_parameters( param_tmp = param.expand_as(param) # Get the gradient accumulator function. grad_acc = param_tmp.grad_fn.next_functions[0][0] - grad_acc.register_hook(self._make_param_hook(param, self.param_to_grad_buffer)) + grad_acc.register_hook(self._make_param_hook(param, self.param_to_buffer)) self.grad_accs.append(grad_acc) def forward(self, *inputs, **kwargs): @@ -164,7 +179,9 @@ def forward(self, *inputs, **kwargs): return self.module(*inputs, **kwargs) def _make_param_hook( - self, param: torch.nn.Parameter, param_to_grad_buffer: Dict[torch.nn.Parameter, GradBuffer] + self, + param: torch.nn.Parameter, + param_to_buffer: Dict[torch.nn.Parameter, ParamAndGradBuffer], ): """ Creates the all-reduce / reduce-scatter hook for backprop. @@ -183,7 +200,7 @@ def param_hook(*unused): param.grad = None if self.overlap_grad_reduce: - param_to_grad_buffer[param].register_grad_ready(param) + param_to_buffer[param].register_grad_ready(param) return param_hook @@ -192,13 +209,13 @@ def no_sync(self): """ Context manager that turns off gradient synchronization. """ - for grad_buffer in self.grad_buffers + self.expert_parallel_grad_buffers: - grad_buffer.is_last_microbatch = False + for buffer in self.buffers + self.expert_parallel_buffers: + buffer.is_last_microbatch = False try: yield finally: - for grad_buffer in self.grad_buffers + self.expert_parallel_grad_buffers: - grad_buffer.is_last_microbatch = True + for buffer in self.buffers + self.expert_parallel_buffers: + buffer.is_last_microbatch = True def start_grad_sync(self, *unused): """ @@ -209,8 +226,8 @@ def start_grad_sync(self, *unused): calls. When overlap_grad_reduce is set to False, calls synchronous communication ops. """ - for grad_buffer in self.grad_buffers + self.expert_parallel_grad_buffers: - grad_buffer.start_grad_sync() + for buffer in self.buffers + self.expert_parallel_buffers: + buffer.start_grad_sync() def finish_grad_sync(self): """ @@ -221,21 +238,19 @@ def finish_grad_sync(self): calls to complete. When overlap_grad_reduce is set to False, calls synchronous communication ops. """ - for grad_buffer in self.grad_buffers + self.expert_parallel_grad_buffers: - grad_buffer.finish_grad_sync() + for buffer in self.buffers + self.expert_parallel_buffers: + buffer.finish_grad_sync() - def zero_grad_buffer(self, zero_buffer): + def zero_grad_buffer(self): """ Zeros out all grad buffers. Needs to be called at the beginning of each training iteration. - - When zero_buffer is set to True, the underlying grad buffer is zeroed out. 
""" for param in self.module.parameters(): if param.requires_grad: param.grad_added_to_main_grad = False - for grad_buffer in self.grad_buffers + self.expert_parallel_grad_buffers: - grad_buffer.reset(zero_buffer) + for buffer in self.buffers + self.expert_parallel_buffers: + buffer.reset() def broadcast_params(self): """ diff --git a/megatron/core/distributed/grad_buffer.py b/megatron/core/distributed/grad_buffer.py index 9b4202596b..dc4d17b32b 100644 --- a/megatron/core/distributed/grad_buffer.py +++ b/megatron/core/distributed/grad_buffer.py @@ -2,8 +2,9 @@ import math import os +from enum import Enum from logging import getLogger -from typing import Dict, List +from typing import Dict, List, Optional import torch @@ -12,6 +13,11 @@ logger = getLogger(__name__) +class BufferType(Enum): + PARAM = 1 + GRAD = 2 + + def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int): """ Shard buffer into data_parallel_world_size chunks of equal size. @@ -32,8 +38,9 @@ class Bucket: Arguments: params: List of parameters whose gradients are collated in this bucket. - data: View in larger GradBuffer that this bucket is responsible for. - offset: Offset of this bucket's view in the larger GradBuffer. + param_data: View in larger ParamAndGradBuffer.param_data that this bucket is responsible for. + grad_data: View in larger ParamAndGradBuffer.grad_data that this bucket is responsible for. + offset: Offset of this bucket's view in the larger ParamAndGradBuffer. numel_unpadded: Number of unpadded elements in bucket. data_parallel_group: Data-parallel process group. data_parallel_world_size: World size using the data-parallel group group. @@ -51,7 +58,8 @@ class Bucket: def __init__( self, params: List[torch.nn.Parameter], - data: torch.Tensor, + param_data: Optional[torch.Tensor], + grad_data: torch.Tensor, offset: int, numel_unpadded: int, data_parallel_group: torch.distributed.ProcessGroup, @@ -68,7 +76,8 @@ def __init__( self.params_list = params self.params = set(params) self.params_with_grad = set() - self.data = data + self.param_data = param_data + self.grad_data = grad_data # The distributed optimizer needs to keep track of this bucket's offset # within the full grad_buffer. self.offset = offset @@ -108,28 +117,28 @@ def start_grad_sync(self): # prior to data-parallel all-reduce / reduce-scatter. if self.check_for_nan_in_grad: global_rank = torch.distributed.get_rank() - norm = self.data.norm(p=2) + norm = self.grad_data.norm(p=2) assert not norm.isnan(), ( f'Rank {global_rank}: found NaN in local grad norm in ' f'backward pass before data-parallel communication collective. ' f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}' ) - self.data *= self.gradient_scaling_factor + self.grad_data *= self.gradient_scaling_factor # Use async_op only when overlap_grad_reduce is True. 
if self.use_distributed_optimizer: - local_data_view = shard_buffer(self.data, self.data_parallel_world_size)[ + local_data_view = shard_buffer(self.grad_data, self.data_parallel_world_size)[ self.data_parallel_rank ] self.communication_handle = torch.distributed._reduce_scatter_base( local_data_view, - self.data, + self.grad_data, group=self.data_parallel_group, async_op=self.overlap_grad_reduce, ) else: self.communication_handle = torch.distributed.all_reduce( - self.data, group=self.data_parallel_group, async_op=self.overlap_grad_reduce + self.grad_data, group=self.data_parallel_group, async_op=self.overlap_grad_reduce ) self.communication_issued = True @@ -169,14 +178,16 @@ def register_grad_ready(self, param: torch.nn.Parameter): self.start_grad_sync() -class GradBuffer: +class ParamAndGradBuffer: """ - Groups gradients into a contiguous buffer, and then breaks the buffer into buckets with - roughly `bucket_size` parameters each. + Groups parameters and gradients into a contiguous buffer, and then breaks the buffer into + buckets with roughly `bucket_size` parameters each. Arguments: - dtype: Type of underlying tensor. - params: List of parameters whose gradients are collated in the underlying tensor. + param_dtype: Type of param tensor. + grad_dtype: Type of grad tensor. + params: List of parameters whose parameters and gradients are collated in the underlying + tensor. data_parallel_group: Data-parallel process group. bucket_size: The rough size of each bucket in terms of number of parameters. param_to_name: Mapping from `torch.nn.Parameter` to name (for logging purposes). @@ -193,7 +204,8 @@ class GradBuffer: def __init__( self, - dtype: torch.dtype, + param_dtype: torch.dtype, + grad_dtype: torch.dtype, params: List[torch.nn.Parameter], data_parallel_group: torch.distributed.ProcessGroup, bucket_size: int, @@ -212,7 +224,8 @@ def __init__( del unique_params # Store attributes that will be needed later. - self.dtype = dtype + self.param_dtype = param_dtype + self.grad_dtype = grad_dtype self.data_parallel_group = data_parallel_group self.data_parallel_world_size = torch.distributed.get_world_size( group=self.data_parallel_group @@ -318,11 +331,23 @@ def _does_param_require_new_bucket(param): self.numel = data_end_index if use_distributed_optimizer: assert self.numel % self.data_parallel_world_size == 0 - self.data = torch.zeros( - self.numel, dtype=self.dtype, device=torch.cuda.current_device(), requires_grad=False, + self.param_data = None + # Only re-map param tensors if using distributed optimizer. + if self.use_distributed_optimizer: + self.param_data = torch.zeros( + self.numel, + dtype=self.param_dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + self.grad_data = torch.zeros( + self.numel, + dtype=self.grad_dtype, + device=torch.cuda.current_device(), + requires_grad=False, ) - # Finally, map main_grad fields for each parameter with a .grad field. + # Finally, map param.data and param.main_grad fields to buffers. bucket_params = set() bucket_data_start_index = 0 cur_bucket_id = 0 @@ -330,7 +355,21 @@ def _does_param_require_new_bucket(param): if not param.requires_grad: continue data_start_index, data_end_index, bucket_id = self.param_index_map[param] - param.main_grad = self._get(param.data.shape, data_start_index) + + # Assign param.data to appropriate segment of self.param_data. 
+ if self.param_data is not None: + old_param_data = param.data + param.data = self._get( + param.data.shape, data_start_index, buffer_type=BufferType.PARAM + ) + assert old_param_data._base is None + # Copy tensor values (from initialization or checkpoint). + param.data.detach().copy_(old_param_data) + del old_param_data + + param.main_grad = self._get( + param.data.shape, data_start_index, buffer_type=BufferType.GRAD + ) if bucket_id != cur_bucket_id: bucket_data_end_index = _pad_if_needed(data_start_index) self._set_bucket( @@ -374,14 +413,20 @@ def _does_param_require_new_bucket(param): for param in bucket.params: logger.info(f' {param_to_name[param]}') - def _get(self, shape: torch.Size, start_index: int) -> torch.Tensor: + def _get(self, shape: torch.Size, start_index: int, buffer_type: BufferType) -> torch.Tensor: """ Return a tensor with the input `shape` as a view into the 1-D data starting at `start_index`. """ end_index = start_index + shape.numel() assert end_index <= self.numel, 'Requested tensor is out of buffer range' - buffer_tensor = self.data[start_index:end_index] + if buffer_type == BufferType.PARAM: + assert self.param_data is not None + buffer_tensor = self.param_data[start_index:end_index] + elif buffer_type == BufferType.GRAD: + buffer_tensor = self.grad_data[start_index:end_index] + else: + raise Exception("Illegal buffer type provided to GradBuffer._get() function") buffer_tensor = buffer_tensor.view(shape) return buffer_tensor @@ -405,11 +450,19 @@ def _set_bucket( assert end_index % self.data_parallel_world_size == 0 assert (start_index, end_index) == self.bucket_indices[bucket_id] - # Get appropriate view into global GradBuffer. - bucket_data = self._get(torch.Size([end_index - start_index]), start_index) + # Get appropriate view into global ParamAndGradBuffer. + bucketed_param_data = None + if self.param_data is not None: + bucketed_param_data = self._get( + torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.PARAM + ) + bucketed_grad_data = self._get( + torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.GRAD + ) bucket = Bucket( params=bucket_params, - data=bucket_data, + param_data=bucketed_param_data, + grad_data=bucketed_grad_data, offset=start_index, numel_unpadded=numel_unpadded, data_parallel_group=self.data_parallel_group, @@ -424,15 +477,12 @@ def _set_bucket( assert bucket_param not in self.param_to_bucket self.param_to_bucket[bucket_param] = bucket - def reset(self, zero_buffer): + def reset(self): """ - Zero out the underlying buffer and reset all buckets in preparation for the next + Zero out the underlying grad_buffer and reset all buckets in preparation for the next iteration of training. - - When zero_buffer is set to True, the underlying buffer is zeroed out. 
""" - if zero_buffer: - self.data.zero_() + self.grad_data.zero_() for bucket in self.buckets: bucket.reset() self.is_last_microbatch = True diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index 639c61e56a..3c4d0c02ab 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -89,7 +89,7 @@ def get_param_groups(model_chunks, no_weight_decay_cond, scale_lr_cond, lr_mult) def get_megatron_optimizer_based_on_param_groups( config, param_groups, - per_model_grad_buffers=None, + per_model_buffers=None, data_parallel_group=None, data_parallel_group_gloo=None, data_parallel_group_idx=None, @@ -101,7 +101,7 @@ def get_megatron_optimizer_based_on_param_groups( Args: param_groups (list): list of parameter groups. - per_model_grad_buffers (list, optional): list of gradient buffers for + per_model_buffers (list, optional): list of buffers for distributed optimizer. Defaults to None. data_parallel_group (ProcessGroup, optional): data parallel group for distributed optimizer. Defaults to None. @@ -184,7 +184,7 @@ def init_state_fn(opt): if config.use_distributed_optimizer: optimizer = DistributedOptimizer( *optimizer_args, - per_model_grad_buffers=per_model_grad_buffers, + per_model_buffers=per_model_buffers, data_parallel_group=data_parallel_group, data_parallel_group_gloo=data_parallel_group_gloo, overlap_param_gather=config.overlap_param_gather, @@ -225,12 +225,12 @@ def get_megatron_optimizer( param_groups = get_param_groups(model_chunks, no_weight_decay_cond, scale_lr_cond, lr_mult) # Collect grad buffers for distributed optimizer. - per_model_grad_buffers = {} - per_model_ep_grad_buffers = {} + per_model_buffers = {} + per_model_ep_buffers = {} for model_idx, model_chunk in enumerate(model_chunks): - if hasattr(model_chunk, 'grad_buffers'): - per_model_grad_buffers[model_idx] = model_chunk.grad_buffers - per_model_ep_grad_buffers[model_idx] = model_chunk.expert_parallel_grad_buffers + if hasattr(model_chunk, 'buffers'): + per_model_buffers[model_idx] = model_chunk.buffers + per_model_ep_buffers[model_idx] = model_chunk.expert_parallel_buffers # Split param groups into dense and moe. dense_param_groups = list(filter(lambda g: not g['is_expert_parallel'], param_groups)) @@ -242,7 +242,7 @@ def get_megatron_optimizer( get_megatron_optimizer_based_on_param_groups( config, param_groups=dense_param_groups, - per_model_grad_buffers=per_model_grad_buffers, + per_model_buffers=per_model_buffers, data_parallel_group=mpu.get_data_parallel_group(with_context_parallel=True), data_parallel_group_gloo=mpu.get_data_parallel_group_gloo(with_context_parallel=True), data_parallel_group_idx=model_parallel_rank, @@ -255,7 +255,7 @@ def get_megatron_optimizer( get_megatron_optimizer_based_on_param_groups( config, param_groups=moe_param_groups, - per_model_grad_buffers=per_model_ep_grad_buffers, + per_model_buffers=per_model_ep_buffers, data_parallel_group=mpu.get_data_modulo_expert_parallel_group(), data_parallel_group_gloo=mpu.get_data_modulo_expert_parallel_group_gloo(), data_parallel_group_idx=expert_parallel_rank * model_parallel_world_size diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index d706f8717f..ad30940191 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -62,8 +62,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): use any loss scale. Note that for `bf16 = True`, we can have a constnat gradient scaler. 
Also for `bf16 = False`, we always require a grad scaler. - grad_buffers: the implementation of the distributed optimizer is - centered on using the contiguous grad buffer for communicating + buffers: the implementation of the distributed optimizer is + centered on using a contiguous buffer for communicating grads & params between the model state and the optimizer state. You can find a more detailed description in this document https://github.com/NVIDIA/Megatron-LM/blob/main/docs/source/distrib_optimizer.md @@ -144,8 +144,7 @@ def build_model_gbuf_range(cls, grad_buffer, bucket_index): data_parallel_world_size = grad_buffer.data_parallel_group.size() bucket = grad_buffer.buckets[bucket_index] - bucket_buffer = bucket.data - gbuf_size = bucket_buffer.numel() + gbuf_size = bucket.grad_data.numel() assert ( gbuf_size % data_parallel_world_size == 0 ), f"Each bucket's buffer size should be divisible by {data_parallel_world_size}" @@ -189,10 +188,10 @@ def build_gbuf_range_map(cls, grad_buffer): shard is 1/dp_world_size of the bucket). Args: - grad_buffer (GradBuffer): grad buffer to build mapping for. + grad_buffer (ParamAndGradBuffer): grad buffer to build mapping for. """ return { - grad_buffer.dtype: [ + (grad_buffer.param_dtype, grad_buffer.grad_dtype): [ cls.build_model_gbuf_range(grad_buffer, bucket_index) for bucket_index in range(len(grad_buffer.buckets)) ] @@ -380,7 +379,7 @@ def __init__( params_dtype, grad_scaler, init_state_fn, - per_model_grad_buffers, + per_model_buffers, overlap_param_gather, data_parallel_group, data_parallel_group_gloo, @@ -413,29 +412,43 @@ def __init__( ), "Only Adam currently supported, due to checkpointing requirements." # Model grad buffer ranges. - assert per_model_grad_buffers, "grad_buffers must be provided" - self.grad_buffers = list(itertools.chain(*per_model_grad_buffers.values())) - self.per_model_grad_buffers = per_model_grad_buffers + assert per_model_buffers, "buffers must be provided" + self.buffers = list(itertools.chain(*per_model_buffers.values())) + self.per_model_buffers = per_model_buffers self.data_parallel_group = data_parallel_group self.data_parallel_group_gloo = data_parallel_group_gloo self.data_parallel_group_idx = data_parallel_group_idx self.gbuf_idx_to_model_idx_map = {} gbuf_idx = 0 - for model_idx, grad_buffers in self.per_model_grad_buffers.items(): - for _ in grad_buffers: + for model_idx, buffers in self.per_model_buffers.items(): + for _ in buffers: self.gbuf_idx_to_model_idx_map[gbuf_idx] = model_idx gbuf_idx += 1 self.gbuf_ranges = [] self.per_bucket_numel = [] self.per_bucket_numel_unpadded = [] - for grad_buffer in self.grad_buffers: + self.param_buffers = [] + for buffer in self.buffers: + # self.param_buffers needs handles to each param_buffer bucket to coordinate all-gather. 
+ self.param_buffers.append([]) + for bucket in buffer.buckets: + self.param_buffers[-1].append(bucket.param_data) + self.per_bucket_numel.append( - {grad_buffer.dtype: [bucket.data.numel() for bucket in grad_buffer.buckets]} + { + (buffer.param_dtype, buffer.grad_dtype): [ + bucket.grad_data.numel() for bucket in buffer.buckets + ] + } ) self.per_bucket_numel_unpadded.append( - {grad_buffer.dtype: [bucket.numel_unpadded for bucket in grad_buffer.buckets]} + { + (buffer.param_dtype, buffer.grad_dtype): [ + bucket.numel_unpadded for bucket in buffer.buckets + ] + } ) - self.gbuf_ranges.append(self.build_gbuf_range_map(grad_buffer)) + self.gbuf_ranges.append(self.build_gbuf_range_map(buffer)) self.model_param_gbuf_map = self.build_model_param_gbuf_map(self.gbuf_ranges) # Optimizer ranges. @@ -454,36 +467,12 @@ def __init__( self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges ) - # Initialize param buffers. - # - These are views on the DDP model's grad buffers, that share - # storage & have their own dtype. This is safe because the param - # dtype size is always <= grad dtype size. - self.param_buffers = [] - for gbuf_index, grad_buffer in enumerate(self.grad_buffers): - size_ratio = torch.finfo(grad_buffer.dtype).bits // torch.finfo(params_dtype).bits - assert ( - size_ratio >= 1 - ), "param_dtype size should be smaller than or equal to grad_dtype size" - current_param_buffers = [] - for bucket in grad_buffer.buckets: - param_buffer = bucket.data.view(dtype=params_dtype) - param_buffer = param_buffer[: bucket.data.numel()] - assert ( - param_buffer.data_ptr() == bucket.data.data_ptr() - ), "param_buffer and grad_buffer for same bucket should start at the same byte address" - assert ( - param_buffer.numel() == bucket.data.numel() - ), "param_buffer and grad_buffer for same bucket should have the same number of elements" - current_param_buffers.append(param_buffer) - self.param_buffers.append(current_param_buffers) - # Now construct data structures to manage all-gather handles. self.all_gather_handles = [] self.all_gather_handle_index_to_bucket_index_map = [] self.model_index_to_all_gather_handle_index_map = {} self.all_gather_handle_indices = [] self.param_to_all_gather_handle_index_map = {} - self.param_buffer_copied = [] self.pbuf_view_items = self.get_model_param_buffer_dp_views() for (gbuf_index, dtype, bucket_index, _, _) in self.pbuf_view_items: @@ -501,9 +490,8 @@ def __init__( all_gather_handle_index ) - for param in self.grad_buffers[gbuf_index].buckets[bucket_index].params_list: + for param in self.buffers[gbuf_index].buckets[bucket_index].params_list: self.param_to_all_gather_handle_index_map[param] = all_gather_handle_index - self.param_buffer_copied.append(False) self.num_all_gather_handles = len(self.all_gather_handle_index_to_bucket_index_map) self.overlap_param_gather = overlap_param_gather @@ -702,7 +690,7 @@ def get_parameter_state(self): for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets): # Compute local DP contiguous shard's size. - gbuf_world_numel = self.grad_buffers[gbuf_idx].buckets[bucket_idx].data.numel() + gbuf_world_numel = self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel() assert gbuf_world_numel % data_parallel_world_size == 0 gbuf_local_numel = gbuf_world_numel // data_parallel_world_size local_shards = { @@ -848,7 +836,7 @@ def load_parameter_state_from_state_dict(self, state_dict): for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets): # Compute local DP contiguous shard's size. 
- gbuf_world_numel = self.grad_buffers[gbuf_idx].buckets[bucket_idx].data.numel() + gbuf_world_numel = self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel() assert gbuf_world_numel == self.per_bucket_numel[gbuf_idx][dtype][bucket_idx] assert gbuf_world_numel % data_parallel_world_size == 0 gbuf_local_numel = gbuf_world_numel // data_parallel_world_size @@ -1016,7 +1004,7 @@ def get_model_param_buffer_dp_views(self): view_items = [] for gbuf_index, buffers in enumerate(self.param_buffers): view_items_per_model_chunk = [] - dtype = self.grad_buffers[gbuf_index].dtype + dtype = self.buffers[gbuf_index].param_dtype for bucket_index, buf in enumerate(buffers): data_parallel_world_size = torch.distributed.get_world_size( self.data_parallel_group @@ -1061,9 +1049,6 @@ def _dispatch_gather_model_params(self, all_gather_handle_index, force_sync=Fals bucket_index, ) - if not async_op: - self._copy_params_from_param_buffer(all_gather_handle_index) - def _make_forward_pre_hook(self): """ Create a forward pre-hook to wait on all-gather handles when necessary (i.e., @@ -1122,42 +1107,6 @@ def _finish_param_sync_helper(self, all_gather_handle_index): if next_all_gather_handle_index < self.num_all_gather_handles: self._dispatch_gather_model_params(next_all_gather_handle_index) - # Also check if we have already copied from the param buffer for this - # handle; if not, complete the copy and mark as such. - if not self.param_buffer_copied[all_gather_handle_index]: - self._copy_params_from_param_buffer(all_gather_handle_index) - self.param_buffer_copied[all_gather_handle_index] = True - - def _copy_params_from_param_buffer(self, all_gather_handle_index): - """ - Copy params from param_buffer to model_params. - """ - (gbuf_index, dtype, bucket_index) = self.all_gather_handle_index_to_bucket_index_map[ - all_gather_handle_index - ] - grad_buffer = self.grad_buffers[gbuf_index] - - if self.update_successful: - # Copy from param buffer to each param. - param_map = grad_buffer.param_index_map - for param, (buf_start, buf_end, bucket_index_in_param_map) in param_map.items(): - if bucket_index == bucket_index_in_param_map: - bucket_offset = grad_buffer.buckets[bucket_index].offset - param_buf = self.param_buffers[gbuf_index][bucket_index] - # buf_start and buf_end store position of this parameter in the full grad_buffer, - # so need to adjust these indices (by subtracting out bucket_offset) since we - # have independent param_bufs for each bucket. - param_buf_shard = param_buf[buf_start - bucket_offset : buf_end - bucket_offset] - assert param.data.nelement() == param_buf_shard.nelement() - param.view(-1).detach().copy_(param_buf_shard) - - # Zero out the grad buffer in preparation for next set of fwd / bwd passes after copy - # completes (since param_buffer and grad_buffer are shared for each bucket). - param_buf = self.param_buffers[gbuf_index][bucket_index] - grad_buf = grad_buffer.buckets[bucket_index].data - assert param_buf.data_ptr() == grad_buf.data_ptr() - grad_buf.zero_() - def _collect_main_grad_data_for_unscaling(self): """ Note: this should be equivalent to the float-16 optimizer's method, @@ -1267,7 +1216,6 @@ def copy_group_params(model_groups, shard_main_groups): def _reset_metadata_and_sync_gather_all_model_params(self, force_sync): # Reset metadata needed to track results of all-gathers. 
self.all_gather_handles = [None for _ in range(len(self.all_gather_handles))] - self.param_buffer_copied = [False for _ in range(len(self.param_buffer_copied))] # Launch synchronous all-gather if --overlap-param-gather is turned on or if force_sync # is explicitly set to True (e.g., if we are going to turn off all-gather overlapping for diff --git a/megatron/training.py b/megatron/training.py index dc9b34ecf3..e988ccd2ab 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -526,10 +526,7 @@ def train_step(forward_step_func, data_iterator, # Set grad to zero. for model_chunk in model: - # If using distributed optimizer, don't zero buffer here; zeroing of buffer is - # handled automatically by the optimizer after all-gathers finish. - # Otherwise, zero the buffer. - model_chunk.zero_grad_buffer(zero_buffer=(not args.use_distributed_optimizer)) + model_chunk.zero_grad_buffer() optimizer.zero_grad() # Forward pass.
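
The sketches below are illustrative notes, not part of the patch. First, the DDP wrapper now keys each buffer on the (param_dtype, grad_dtype) pair instead of the gradient dtype alone, so bf16 parameters that accumulate gradients in fp32 end up in a different buffer than fp32 parameters. A minimal, CPU-only sketch of that grouping step; the `accumulate_allreduce_grads_in_fp32` flag mirrors the patch, while the helper name and example tensors are hypothetical:

    from collections import defaultdict
    from typing import Dict, List, Tuple

    import torch


    def group_params_by_dtypes(
        params: List[torch.nn.Parameter], accumulate_allreduce_grads_in_fp32: bool
    ) -> Dict[Tuple[torch.dtype, torch.dtype], List[torch.nn.Parameter]]:
        """Group params by (param_dtype, grad_dtype), mirroring allocate_buffers_for_parameters."""
        groups = defaultdict(list)
        for param in params:
            if not param.requires_grad:
                continue
            param_dtype = param.dtype
            grad_dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype
            groups[(param_dtype, grad_dtype)].append(param)
        return dict(groups)


    # Example: bf16 params with fp32 grad accumulation land under (bfloat16, float32),
    # fp32 params under (float32, float32).
    params = [
        torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16)),
        torch.nn.Parameter(torch.zeros(8, dtype=torch.float)),
    ]
    print(group_params_by_dtypes(params, accumulate_allreduce_grads_in_fp32=True).keys())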
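
The heart of `ParamAndGradBuffer` is a pair of flat tensors, `param_data` (allocated only when the distributed optimizer is used) and `grad_data`, with each parameter's `.data` and `.main_grad` re-pointed at shaped views into them; the old parameter values are copied in first so initialization and checkpoint loads survive the re-mapping. A toy, CPU-only version of that view-and-copy pattern, ignoring buckets and padding:

    import torch


    def remap_params_to_flat_buffers(params, param_dtype=torch.float, grad_dtype=torch.float):
        """Toy ParamAndGradBuffer: flat param_data/grad_data plus per-param views."""
        numel = sum(p.data.nelement() for p in params)
        param_data = torch.zeros(numel, dtype=param_dtype)
        grad_data = torch.zeros(numel, dtype=grad_dtype)

        offset = 0
        for param in params:
            n = param.data.nelement()
            view = param_data[offset:offset + n].view(param.data.shape)
            view.copy_(param.data)        # preserve initialized / checkpoint-loaded values
            param.data = view             # param now aliases the flat buffer
            param.main_grad = grad_data[offset:offset + n].view(param.data.shape)
            offset += n
        return param_data, grad_data


    params = [torch.nn.Parameter(torch.randn(3, 2)), torch.nn.Parameter(torch.randn(5))]
    param_data, grad_data = remap_params_to_flat_buffers(params)
    # The first param's storage now starts at the start of the flat buffer.
    assert params[0].data.data_ptr() == param_data.data_ptr()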
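
The `weight_tensor` cleanup uses `torch.nn.Module.apply`, which visits every submodule; dropping the attribute avoids holding a second copy of each weight once `param.data` has been re-pointed at the shared buffer (the underlying TE issue is fixed upstream in NVIDIA/TransformerEngine#719, as the patch comment notes). The same pattern, lifted out of the patch into a standalone helper for reference:

    import torch


    def drop_cached_weight_tensors(model: torch.nn.Module) -> None:
        """Release any `weight_tensor` caches so the re-mapped param is the only copy."""

        @torch.no_grad()
        def unmap_weight_tensor(m: torch.nn.Module) -> None:
            if hasattr(m, 'weight_tensor'):
                m.weight_tensor = None

        model.apply(unmap_weight_tensor)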
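
On the communication side, `Bucket.start_grad_sync` now operates on `grad_data`: with the distributed optimizer it reduce-scatters into the local rank's shard (obtained via `shard_buffer`), otherwise it all-reduces the whole bucket. A sketch of just that branch, assuming an already-initialized `torch.distributed` process group and a 1-D bucket tensor whose length is divisible by the data-parallel world size; scaling and the NaN check are omitted:

    import torch
    import torch.distributed as dist


    def shard_buffer(buffer: torch.Tensor, world_size: int):
        """Split a 1-D buffer into world_size equal chunks (matches the helper in grad_buffer.py)."""
        assert buffer.numel() % world_size == 0
        shard_size = buffer.numel() // world_size
        return [buffer[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]


    def sync_bucket_grads(grad_data: torch.Tensor, group, use_distributed_optimizer: bool,
                          overlap_grad_reduce: bool = False):
        """Reduce-scatter into the local shard, or all-reduce the whole bucket."""
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)
        if use_distributed_optimizer:
            local_view = shard_buffer(grad_data, world_size)[rank]
            # Same (private) collective the patch uses.
            return dist._reduce_scatter_base(
                local_view, grad_data, group=group, async_op=overlap_grad_reduce
            )
        return dist.all_reduce(grad_data, group=group, async_op=overlap_grad_reduce)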
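
Finally, because each `param.data` already aliases `bucket.param_data`, the distributed optimizer no longer needs dtype-punned views of the grad buffer or a copy-back step after the all-gather: `self.param_buffers` is now just a list of handles to `bucket.param_data`, and whatever the all-gather writes there is immediately visible through the model parameters. A single-process sketch of that aliasing, with a plain `copy_` standing in for the collective:

    import torch

    # Flat buffer standing in for bucket.param_data; the two "params" are views into it.
    param_data = torch.zeros(6)
    weight = torch.nn.Parameter(param_data[0:4].view(2, 2))
    bias = torch.nn.Parameter(param_data[4:6])

    # Whatever a collective (e.g. the param all-gather) writes into param_data shows up
    # in the parameters without any explicit copy-back.
    param_data.copy_(torch.arange(6, dtype=torch.float))
    assert torch.equal(weight.data, torch.tensor([[0., 1.], [2., 3.]]))
    assert torch.equal(bias.data, torch.tensor([4., 5.]))

This is also why `_copy_params_from_param_buffer` and the `param_buffer_copied` bookkeeping are removed, and why `zero_grad_buffer()` / `reset()` drop the `zero_buffer` flag and always zero `grad_data`: parameters and gradients no longer share bytes, so zeroing gradients cannot clobber freshly gathered parameter values.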