diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 76583c129cb9..2089d59dbce4 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -20,6 +20,7 @@ "stage": [0|1|2], "stage3_max_live_parameters" : 1000000000, "stage3_max_reuse_distance" : 1000000000, + "stage3_use_all_reduce_for_fetch_params": [true|false], "allgather_partitions": [true|false], "use_multi_rank_bucket_allreduce": [true|false], "allgather_bucket_size": 500000000, @@ -234,6 +235,12 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): this option is enabled and then saves the fp16 model weights. """ + use_all_reduce_for_fetch_params: bool = Field(False, alias="stage3_use_all_reduce_for_fetch_params") + """ + Use all_reduce op when fetching module parameters at stage3. This improves performance by reducing + the overhead of concatenation and slicing on the host. + """ + stage3_gather_fp16_weights_on_model_save: bool = Field(False, deprecated=True, new_param="gather_16bit_weights_on_model_save") diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index c8099791f882..b2909145c50c 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -23,7 +23,7 @@ from deepspeed.utils import groups import deepspeed -from ..utils import see_memory_usage +from ..utils import see_memory_usage, get_only_unique_item from deepspeed.runtime.zero.config import DeepSpeedZeroConfig from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks, is_zero_param from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum @@ -716,6 +716,31 @@ def wait(self) -> None: handle.wait() +class AllReduceCoalescedHandle: + + def __init__(self, handle, params: List[Parameter]) -> None: + self.handle = handle + self.params = params + self.complete = False + + for param in self.params: + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError(f"expected param {param.ds_summary()} to not be available") + + @instrument_w_nvtx + def wait(self) -> None: + if self.complete: + return + + instrument_w_nvtx(self.handle.wait)() + + for param in self.params: + assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight" + param.ds_status = ZeroParamStatus.AVAILABLE + + self.complete = True + + class QuantizationInfo: # a placeholder object to store all quant related vars used in handles def __init__(self) -> None: @@ -1003,6 +1028,11 @@ def __init__( if not self.use_all_gather_into_tensor: logger.info(f"all_gather_into_tensor API is not available in torch {torch.__version__}") + self.use_all_reduce_for_fetch_params = get_config_default(DeepSpeedZeroConfig, + "use_all_reduce_for_fetch_params") + if _ds_config is not None: + self.use_all_reduce_for_fetch_params = _ds_config.zero_config.use_all_reduce_for_fetch_params + def _update_persist_config(self, ds_config): Init.apply_param_persistence = True Init.param_persistence_threshold = ds_config.zero_config.param_persistence_threshold @@ -1250,75 +1280,99 @@ def all_gather_coalesced(params: Iterable[Parameter], return AllGatherHandle(handle, param, quantization=quant_info) else: - if not quantize: - dtype_params = defaultdict(list) - for p in params: - dtype_params[p.ds_tensor.dtype].append(p) - handles = [] - for dtype, params in dtype_params.items(): - handles.append(_all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group)) + if self.use_all_reduce_for_fetch_params and not quantize and not use_secondary_tensor: + # Use all_reduce instead of all_gather to fetch the module params + flat_buffer_size = sum(p.ds_numel_aligned for p in params) + flat_tensor = torch.zeros(flat_buffer_size, + dtype=get_only_unique_item(p.ds_tensor.dtype for p in params), + device=get_accelerator().current_device_name(), + requires_grad=False) + start_param = 0 + for param in params: + param.data = flat_tensor.narrow(0, start_param, param.ds_numel).view(param.ds_shape) + start = start_param + param.ds_tensor.ds_numel * self.get_partition_rank() + flat_tensor.narrow(0, start, param.ds_tensor.ds_numel).copy_(param.ds_tensor) - return MultipleAllGatherHandles(handles) + start_param += param.ds_numel + handle = dist.all_reduce(flat_tensor, group=ds_process_group, async_op=True) + + return AllReduceCoalescedHandle(handle=handle, params=params) else: - partition_sz = sum(p.ds_tensor.ds_numel for p in params) + if not quantize: + dtype_params = defaultdict(list) + for p in params: + dtype_params[p.ds_tensor.dtype].append(p) + handles = [] + for dtype, params in dtype_params.items(): + handles.append( + _all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group)) - if use_secondary_tensor: - partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params) + return MultipleAllGatherHandles(handles) - flat_tensor = torch.empty(partition_sz * world_size, - dtype=torch.int8, - device=get_accelerator().current_device_name(), - requires_grad=False) - - if use_secondary_tensor: - if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"): - quantized_param = instrument_w_nvtx(torch.cat)([ - p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) for p in params - ]) - scales = instrument_w_nvtx(torch.cat)([ - p.ds_secondary_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) - for p in params - ]) - else: - quantized_param, scales = self.quantizer_module.quantize( - instrument_w_nvtx(torch.cat)([ - p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params - ])) else: - if hasattr(params[0].ds_tensor, "ds_quant_scale"): - quantized_param = instrument_w_nvtx(torch.cat)( - [p.ds_tensor.data.to(get_accelerator().current_device_name()) for p in params]) - scales = instrument_w_nvtx(torch.cat)([ - p.ds_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) for p in params - ]) + partition_sz = sum(p.ds_tensor.ds_numel for p in params) + + if use_secondary_tensor: + partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups + for p in params) + + flat_tensor = torch.empty(partition_sz * world_size, + dtype=torch.int8, + device=get_accelerator().current_device_name(), + requires_grad=False) + + if use_secondary_tensor: + if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"): + quantized_param = instrument_w_nvtx(torch.cat)([ + p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) + for p in params + ]) + scales = instrument_w_nvtx(torch.cat)([ + p.ds_secondary_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) + for p in params + ]) + else: + quantized_param, scales = self.quantizer_module.quantize( + instrument_w_nvtx(torch.cat)([ + p.ds_secondary_tensor.to(get_accelerator().current_device_name()) + for p in params + ])) else: - quantized_param, scales = self.quantizer_module.quantize( - instrument_w_nvtx(torch.cat)( - [p.ds_tensor.to(get_accelerator().current_device_name()) for p in params])) - quant_scale_buffer = torch.empty( - scales.numel() * world_size, - dtype=torch.float32, - device=get_accelerator().current_device_name(), - requires_grad=False, - ) - handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group) - quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group) - quant_info = QuantizationInfo() - quant_info.quantized_param = flat_tensor - quant_info.backend = self.quantizer_module - quant_info.quant_handle = quant_handle - quant_info.scale_buffer = quant_scale_buffer - quant_info.partition_sz = partition_sz - quant_info.world_size = world_size - return AllGatherCoalescedHandle( - allgather_handle=handle, - params=params, - partitions=None, - world_size=world_size, - use_secondary_tensor=use_secondary_tensor, - quantization=quant_info, - ) + if hasattr(params[0].ds_tensor, "ds_quant_scale"): + quantized_param = instrument_w_nvtx(torch.cat)( + [p.ds_tensor.data.to(get_accelerator().current_device_name()) for p in params]) + scales = instrument_w_nvtx(torch.cat)([ + p.ds_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) + for p in params + ]) + else: + quantized_param, scales = self.quantizer_module.quantize( + instrument_w_nvtx(torch.cat)( + [p.ds_tensor.to(get_accelerator().current_device_name()) for p in params])) + quant_scale_buffer = torch.empty( + scales.numel() * world_size, + dtype=torch.float32, + device=get_accelerator().current_device_name(), + requires_grad=False, + ) + handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group) + quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group) + quant_info = QuantizationInfo() + quant_info.quantized_param = flat_tensor + quant_info.backend = self.quantizer_module + quant_info.quant_handle = quant_handle + quant_info.scale_buffer = quant_scale_buffer + quant_info.partition_sz = partition_sz + quant_info.world_size = world_size + return AllGatherCoalescedHandle( + allgather_handle=handle, + params=params, + partitions=None, + world_size=world_size, + use_secondary_tensor=use_secondary_tensor, + quantization=quant_info, + ) def partition(param_list=None, hierarchy=0, has_been_updated=False): cls = param @@ -1554,6 +1608,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): param.ds_tensor.ds_numel = partition_size param.ds_tensor.status = PartitionedParamStatus.AVAILABLE param.ds_tensor.final_location = final_location + param.ds_numel_aligned = tensor_size start = partition_size * self.get_partition_rank() end = start + partition_size