From ec49222cd8434fcfc3206fc8579aa8a6819cc6e4 Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Tue, 6 Feb 2024 13:50:39 -0800 Subject: [PATCH 01/27] Update nv-accelerate to latest torch (#5040) nv-accelerate were failing because accelerate required an older version. ``` Collecting torch<2.2.0,>=1.10.0 (from accelerate==0.27.0.dev0) ``` This updates us back to the latest. Previous test in log [here](https://github.com/microsoft/DeepSpeed/actions/runs/7799250328/job/21269656349#step:7:41), and new test in log [here](https://github.com/microsoft/DeepSpeed/actions/runs/7803089462/job/21282030153#step:7:41). --- .github/workflows/nv-accelerate-v100.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/nv-accelerate-v100.yml b/.github/workflows/nv-accelerate-v100.yml index 96d0ef3c9e28..93286b62610a 100644 --- a/.github/workflows/nv-accelerate-v100.yml +++ b/.github/workflows/nv-accelerate-v100.yml @@ -29,7 +29,7 @@ jobs: - name: Install pytorch run: | - pip install -U --cache-dir $TORCH_CACHE torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118 + pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu118 python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" From b42a470615453ed23c406d362eba074e36eb8126 Mon Sep 17 00:00:00 2001 From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com> Date: Wed, 7 Feb 2024 20:50:04 +0200 Subject: [PATCH 02/27] HPU Accelerator: fix supported_dtypes API (#5094) was returning bfloat16 instead of half --- accelerator/hpu_accelerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py index 120e038dd227..eba4a50ad807 100644 --- a/accelerator/hpu_accelerator.py +++ b/accelerator/hpu_accelerator.py @@ -147,7 +147,7 @@ def is_fp16_supported(self): def supported_dtypes(self): supported_dtypes = [torch.float, torch.bfloat16] if self.is_fp16_supported(): - supported_dtypes.append(torch.bfloat16) + supported_dtypes.append(torch.half) return supported_dtypes # Misc From 4f477328c411270cf378a2318bc4f51c512ad2c8 Mon Sep 17 00:00:00 2001 From: minchao Date: Thu, 8 Feb 2024 03:02:37 +0800 Subject: [PATCH 03/27] [NPU] replace 'cuda' with get_accelerator().device_name() (#5095) Replace 'cuda' with `get_accelerator().device_name()` to support other accelerators. I searched the whole repo trying to fix same issue and this seems to be the only one. --- deepspeed/runtime/zero/stage3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index b76b781346e7..8f1827892500 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -198,7 +198,7 @@ def __init__( # backup fused_adam optimizer init if self.offload_optimizer and self.partial_offload != 1.0: - backup_gpu_tensor = torch.randn(1, device='cuda').to(self.dtype) + backup_gpu_tensor = torch.randn(1, device=get_accelerator().device_name()).to(self.dtype) backup_gpu_param = torch.nn.Parameter(backup_gpu_tensor) assert type(init_optimizer) == DeepSpeedCPUAdam, 'Hybrid Optimizer Only Supports DeepSpeedCPUAdam' self.backup_optimizer = FusedAdam([backup_gpu_param], From 961bc85624174e8ca8ee7626b3f3b53c6c768085 Mon Sep 17 00:00:00 2001 From: mmhab <132277730+mmhab@users.noreply.github.com> Date: Wed, 7 Feb 2024 21:23:08 +0200 Subject: [PATCH 04/27] optimize clip_grad_norm_ function (#4915) Optimize clip_grad_norm_ function by removing .item() calls to reduce wait time for the device on the host. Co-authored-by: Olatunji Ruwase Co-authored-by: Michael Wyatt --- deepspeed/runtime/utils.py | 45 +++++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 82f200fccf9f..d7a35b7dbbe9 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -363,44 +363,53 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None): if isinstance(parameters, torch.Tensor): parameters = [parameters] parameters = list(filter(lambda p: p.grad is not None, parameters)) - max_norm = float(max_norm) norm_type = float(norm_type) + all_norms = [] if norm_type == inf: - total_norm = max(p.grad.data.abs().max() for p in parameters) - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + for p in parameters: + all_norms.append(p.grad.data.abs().max().float()) + total_norm = torch.stack(all_norms).max() + origin_device = total_norm.device.type + total_norm = total_norm.to(get_accelerator().device_name()) # Take max across all GPUs. if mpu is not None: - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) - total_norm = total_norm_cuda[0].item() + dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) else: total_norm = 0 for p in parameters: if mpu is not None: if (mpu.get_model_parallel_rank() == 0) or is_model_parallel_parameter(p): - param_norm = p.grad.data.norm(norm_type) - total_norm += param_norm.item()**norm_type + param_norm = p.grad.data.detach().float().norm(norm_type) + all_norms.append(param_norm) else: - param_norm = p.grad.data.float().norm(norm_type) - total_norm += param_norm.item()**norm_type - + param_norm = p.grad.data.detach().float().norm(norm_type) + all_norms.append(param_norm) + if len(all_norms) > 0: + total_norm = torch.stack(all_norms).square().sum().float() + else: + total_norm = torch.FloatTensor([0.0]).to(parameters[0].device) + origin_device = total_norm.device.type + total_norm = total_norm.to(get_accelerator().device_name()) # Sum across all model parallel GPUs. - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) if mpu is not None: - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) - total_norm = total_norm_cuda[0].item()**(1. / norm_type) + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) + total_norm = total_norm.pow(1. / norm_type) # Need to average total_norm across different GPUs due to the presence of moe params pg = groups._get_data_parallel_group() scaled_norm = total_norm * 1.0 / float(dist.get_world_size(group=pg)) + scaled_norm_tensor = scaled_norm - scaled_norm_tensor = get_accelerator().FloatTensor([float(scaled_norm)]) dist.all_reduce(scaled_norm_tensor, group=pg) - total_norm = scaled_norm_tensor.item() + total_norm = scaled_norm_tensor + total_norm = total_norm.to(origin_device) + max_norm = torch.tensor([float(max_norm)], device=parameters[0].device) clip_coef = max_norm / (total_norm + 1e-6) - if clip_coef < 1: - for p in parameters: - p.grad.data.mul_(clip_coef) + tmp_tensor = torch.tensor([1.0], device=parameters[0].device) + clip_coef = torch.max(tmp_tensor, clip_coef) + for p in parameters: + p.grad.data.mul_(clip_coef) return total_norm From 688239e3f24f7ba11d3fe90bbe9670b7a61e5440 Mon Sep 17 00:00:00 2001 From: Yun Dai Date: Wed, 7 Feb 2024 11:49:05 -0800 Subject: [PATCH 05/27] [xs] fix ZEROPP convergence test (#5061) there may be chances where the dataset shard loaded contains example with empty text `''` which will make the test fail (which occurred on my end) so fixing by dropping the empty examples --------- Co-authored-by: Michael Wyatt Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- tests/unit/runtime/zero/test_zeropp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/runtime/zero/test_zeropp.py b/tests/unit/runtime/zero/test_zeropp.py index 545ed98ad2ef..7a05c2a8001b 100644 --- a/tests/unit/runtime/zero/test_zeropp.py +++ b/tests/unit/runtime/zero/test_zeropp.py @@ -202,7 +202,7 @@ def load_and_prepare_data(self, model_name): tokenizer.pad_token = tokenizer.eos_token # Load and tokenize dataset - dataset = load_dataset("wikitext", 'wikitext-103-raw-v1', split='train[:1%]') + dataset = load_dataset("wikitext", 'wikitext-103-raw-v1', split='train[:1%]').filter(lambda x: x["text"]) def tokenize_function(examples): # Tokenize and ensure 'labels' are the same as 'input_ids' From 3255569b785f4561a0bce7913f7f00c1f6fde0c6 Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:16:34 -0800 Subject: [PATCH 06/27] Switch hasattr check from compile to compiler (#5096) torch.compile introduced in torch 2.0 but torch.compiler was introduced in torch 2.1, this fixes issues for those building with torch 2.0 like the A6000 builds. ``` >>> torch.__version__ '2.1.0+cu121' >>> hasattr(torch, "compile") True >>> hasattr(torch, "compiler") True >>> torch.__version__ '2.0.0+cu117' >>> hasattr(torch, "compile") True >>> hasattr(torch, "compiler") False ``` --- deepspeed/runtime/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py index 603f563fca60..b2b612c85180 100644 --- a/deepspeed/runtime/compiler.py +++ b/deepspeed/runtime/compiler.py @@ -13,7 +13,7 @@ def is_compile_supported(): - return hasattr(torch, "compile") + return hasattr(torch, "compiler") def disable(func): From 697f945a05fd75eb3ef299309ee36a6283df2c98 Mon Sep 17 00:00:00 2001 From: BacharL Date: Thu, 8 Feb 2024 18:03:56 +0200 Subject: [PATCH 07/27] Split is_synchronized_device api to multiple apis (#5026) Deepspeed currently calls is_synchronized_device() to decide how to use the device. HPU does not fit into this definition since it behaves like all streams are blocking streams, meaning they preserve order between each other but asynchronous to CPU. see cudaStreamCreateWithFlags. **has_data_dependency_resolving()** HPU device is considered synchronized wrt CPU. Operations executed in the script order regardless of stream they were enqueued on. Tensor data is guaranteed to be valid. No need to stream dependencies or CPU synchronizations. **use_host_timers()** HPU device execution is async. To measure device execution time we must use device timers. **has_memory_backpressure()** limiting number of inflight fetched params and number of inflight grad reduce_scatter calls is not necessary since HPU will stop enqueuing calls if memory is full, creating internal backpressure for the CPU until memory is available. --------- Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- accelerator/abstract_accelerator.py | 12 ++++++++++++ accelerator/cpu_accelerator.py | 9 +++++++++ accelerator/cuda_accelerator.py | 9 +++++++++ accelerator/hpu_accelerator.py | 9 +++++++++ accelerator/npu_accelerator.py | 9 +++++++++ accelerator/xpu_accelerator.py | 9 +++++++++ .../zero/partitioned_param_coordinator.py | 4 ++-- deepspeed/runtime/zero/stage3.py | 16 ++++++++-------- deepspeed/runtime/zero/stage_1_and_2.py | 8 +++++--- deepspeed/utils/timer.py | 2 +- 10 files changed, 73 insertions(+), 14 deletions(-) diff --git a/accelerator/abstract_accelerator.py b/accelerator/abstract_accelerator.py index 8ba652a8de92..3c5d799e293e 100644 --- a/accelerator/abstract_accelerator.py +++ b/accelerator/abstract_accelerator.py @@ -17,6 +17,18 @@ def __init__(self): def is_synchronized_device(self): ... + @abc.abstractmethod + def use_host_timers(self): + ... + + @abc.abstractmethod + def resolves_data_dependency(self): + ... + + @abc.abstractmethod + def handles_memory_backpressure(self): + ... + # Device APIs @abc.abstractmethod def device_name(self, device_index): diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index a02777f5223b..b1aba75b4c5e 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -21,6 +21,15 @@ def __init__(self): def is_synchronized_device(self): return True + def use_host_timers(self): + return self.is_synchronized_device() + + def resolves_data_dependency(self): + return self.is_synchronized_device() + + def handles_memory_backpressure(self): + return self.is_synchronized_device() + # Device APIs def device_name(self, device_index=None): return 'cpu' diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index c6285e8bc97a..2030f36631e9 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -42,6 +42,15 @@ def _init_pynvml(self): def is_synchronized_device(self): return False + def use_host_timers(self): + return self.is_synchronized_device() + + def resolves_data_dependency(self): + return self.is_synchronized_device() + + def handles_memory_backpressure(self): + return self.is_synchronized_device() + # Device APIs def device_name(self, device_index=None): if device_index is None: diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py index eba4a50ad807..30b115e8b1ab 100644 --- a/accelerator/hpu_accelerator.py +++ b/accelerator/hpu_accelerator.py @@ -30,6 +30,15 @@ def __init__(self): def is_synchronized_device(self): return False + def use_host_timers(self): + return False + + def resolves_data_dependency(self): + return True + + def handles_memory_backpressure(self): + return True + def device_name(self, device_index=None): if device_index is None: return 'hpu' diff --git a/accelerator/npu_accelerator.py b/accelerator/npu_accelerator.py index 4e20445d9d32..472157e32c02 100644 --- a/accelerator/npu_accelerator.py +++ b/accelerator/npu_accelerator.py @@ -28,6 +28,15 @@ def __init__(self): def is_synchronized_device(self): return False + def use_host_timers(self): + return self.is_synchronized_device() + + def resolves_data_dependency(self): + return self.is_synchronized_device() + + def handles_memory_backpressure(self): + return self.is_synchronized_device() + # Device APIs def device_name(self, device_index=None): if device_index is None: diff --git a/accelerator/xpu_accelerator.py b/accelerator/xpu_accelerator.py index c3d6630cc235..3f65263946ab 100644 --- a/accelerator/xpu_accelerator.py +++ b/accelerator/xpu_accelerator.py @@ -19,6 +19,15 @@ def __init__(self): def is_synchronized_device(self): return False + def use_host_timers(self): + return self.is_synchronized_device() + + def resolves_data_dependency(self): + return self.is_synchronized_device() + + def handles_memory_backpressure(self): + return self.is_synchronized_device() + # Device APIs def device_name(self, device_index=None): if device_index == None: diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index cfeae9e7839a..8fc962c4f2a7 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -308,13 +308,13 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: self.__inflight_param_registry.pop(param).wait() - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().handles_memory_backpressure(): event = get_accelerator().Event() event.record() self.__ongoing_fetch_events.append(event) assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): get_accelerator().current_stream().wait_stream(self.__allgather_stream) self.__profiler.stop_event(wait_event_name, wait_numel) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 8f1827892500..d971092ebd17 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1087,7 +1087,7 @@ def independent_gradient_partition_epilogue(self): self.__reduce_and_partition_ipg_grads() self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0) - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): self.reduce_and_partition_stream.synchronize() for param_id in self.params_already_reduced.keys(): @@ -1231,7 +1231,7 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param): @instrument_w_nvtx @torch.no_grad() def __add_grad_to_ipg_bucket(self, param: Parameter) -> None: - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): self.reduce_and_partition_stream.wait_stream(get_accelerator().default_stream()) if self.contiguous_gradients and self.elements_in_ipg_bucket + param.grad.numel() <= self.reduce_bucket_size: @@ -1280,7 +1280,7 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: self.params_in_ipg_bucket.clear() - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().handles_memory_backpressure(): event = get_accelerator().Event() event.record() self.param_reduce_events.append(event) @@ -2153,7 +2153,7 @@ def has_overflow(self, partition_gradients=True): overflow_gpu = self.inf_or_nan_tracker.clone().to(torch.uint8) self.inf_or_nan_tracker.zero_() - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): get_accelerator().default_stream().wait_stream(self.reduce_and_partition_stream) dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group) @@ -2224,7 +2224,7 @@ def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]: """get fp32 gradient partition dictionary accessed as grad_dict[parameter_group_index][parameter_index] """ - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): self.reduce_and_partition_stream.synchronize() grad_dict = collections.defaultdict(dict) if self.offload_optimizer: @@ -2254,7 +2254,7 @@ def get_fp32_grad_for_param(self, param) -> Tensor: if not param.requires_grad: return None - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): self.reduce_and_partition_stream.synchronize() if self.offload_optimizer: @@ -2266,7 +2266,7 @@ def get_fp32_grad_for_param(self, param) -> Tensor: return self._fp32_state_allgather(param, fp32_grad) def _get_fp32_opt_state_partition(self, param, optim_state_key=None): - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): self.reduce_and_partition_stream.synchronize() group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] @@ -2323,7 +2323,7 @@ def get_local_fp32_grad_for_param(self, param) -> Tensor: if not param.requires_grad: return None - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): self.reduce_and_partition_stream.synchronize() if self.offload_optimizer: diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 56607b349ae7..48044c0161a2 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -744,7 +744,8 @@ def independent_gradient_partition_epilogue(self): self.params_already_reduced[i] = False if self.overlap_comm: - get_accelerator().synchronize() + if not get_accelerator().resolves_data_dependency(): + get_accelerator().synchronize() # It is safe to clear previously reduced grads of other partitions self._clear_previous_reduced_grads() @@ -1020,7 +1021,7 @@ def allreduce_and_scatter(self, bucket, numel_per_bucket=500000000, log=None, di def average_tensor(self, tensor): if self.overlap_comm: stream = self.reduction_stream - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): stream.wait_stream(get_accelerator().current_stream()) else: stream = get_accelerator().current_stream() @@ -1501,7 +1502,8 @@ def _clear_previous_reduced_grads(self): def allreduce_and_copy(self, small_bucket, rank=None, log=None, divide=True, process_group=None): process_group = self.dp_process_group if process_group is None else process_group if self.overlap_comm: - get_accelerator().synchronize() + if not get_accelerator().resolves_data_dependency(): + get_accelerator().synchronize() # It is safe to clear the previously reduced grads of other partitions self._clear_previous_reduced_grads() stream = self.reduction_stream diff --git a/deepspeed/utils/timer.py b/deepspeed/utils/timer.py index 4aac4ded1243..11ef54fe4665 100755 --- a/deepspeed/utils/timer.py +++ b/deepspeed/utils/timer.py @@ -50,7 +50,7 @@ def __init__(self, name): self.name_ = name self.started_ = False self.event_timers = [] - self.use_host_timer = get_accelerator().is_synchronized_device() + self.use_host_timer = get_accelerator().use_host_timers() self.start_event = None self.elapsed_records = None self.start_time = 0.0 From 3c811c966bc2cbabdd4b097946d09c05b176beee Mon Sep 17 00:00:00 2001 From: Heyang Qin Date: Thu, 8 Feb 2024 10:53:08 -0800 Subject: [PATCH 08/27] 47% FastGen speedup for low workload - refactor allocator (#5090) This PR refactor FastGen allocator and add caching for empty_from method ![image](https://github.com/microsoft/DeepSpeed/assets/46639297/16d1f736-358c-4224-b88f-a708eef8f3a4) DS Master: Deployment: Mixtral-8x7B-v0.1-tp4-b768 Clients: 1, Prompt (mean): 500 tokens, Generation (mean): 1024 tokens, Query throughput: 0.075 queries/s, Token throughput (total): **163.130 tokens/s**, Query latency: 13.310 s, Token generation latency: 0.020 s/token, First token received: 0.055 s This PR: Deployment: Mixtral-8x7B-v0.1-tp4-b768-allocator-rework Clients: 1, Prompt (mean): 500 tokens, Generation (mean): 1024 tokens, Query throughput: 0.095 queries/s, Token throughput (total): **240.386 tokens/s**, Query latency: 10.472 s, Token generation latency: 0.016 s/token, First token received: 0.056 s --------- Co-authored-by: Olatunji Ruwase --- deepspeed/inference/v2/allocator.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/deepspeed/inference/v2/allocator.py b/deepspeed/inference/v2/allocator.py index bebdcf83aee3..fcc0d94c0f82 100644 --- a/deepspeed/inference/v2/allocator.py +++ b/deepspeed/inference/v2/allocator.py @@ -5,17 +5,27 @@ from functools import reduce from typing import Iterable - +from collections import defaultdict import torch from deepspeed.accelerator import get_accelerator -def empty_from(tensor: torch.Tensor, shape: Iterable[int]) -> torch.Tensor: - shape_size = reduce(lambda x, y: x * y, shape) - if shape_size == 0: - raise ValueError("Cannot create empty tensor with size 0") - return tensor.flatten()[:shape_size].view(shape) +class Allocator: + cache = defaultdict(dict) + + def empty_from(tensor: torch.Tensor, shape: Iterable[int]) -> torch.Tensor: + try: + return Allocator.cache[tensor][shape] + except KeyError: + shape_size = reduce(lambda x, y: x * y, shape) + if shape_size == 0: + raise ValueError("Cannot create empty tensor with size 0") + Allocator.cache[tensor][shape] = tensor.flatten()[:shape_size].view(shape) + return Allocator.cache[tensor][shape] + + +empty_from = Allocator.empty_from def on_device(method) -> torch.Tensor: From 2518cc429d51b1371e63e1aeecd22fd92c9e89e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Storhaug?= Date: Thu, 8 Feb 2024 21:57:57 +0100 Subject: [PATCH 09/27] Support `exclude_frozen_parameters` for `zero_to_fp32.py` script (#4979) Adds support for the `zero_to_fp32.py` script to merge only the trainable parameters through a new argument `only_trainable_params`. Fixes #3437 --------- Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/utils/zero_to_fp32.py | 36 ++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index 49b846633d6e..24cc342e78d1 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -191,7 +191,7 @@ def parse_optim_states(files, ds_checkpoint_dir): return zero_stage, world_size, fp32_flat_groups -def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): """ Returns fp32 state_dict reconstructed from ds checkpoint @@ -211,9 +211,11 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') if zero_stage <= 2: - return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states) + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) elif zero_stage == 3: - return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states) + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) def _zero2_merge_frozen_params(state_dict, zero_model_states): @@ -326,7 +328,8 @@ def zero2_align(x): print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") -def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states): +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): state_dict = OrderedDict() # buffers @@ -335,7 +338,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zer if debug: print(f"added {len(buffers)} buffers") - _zero2_merge_frozen_params(state_dict, zero_model_states) + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) @@ -444,7 +448,8 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") -def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states): +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): state_dict = OrderedDict() # buffers @@ -453,7 +458,8 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer if debug: print(f"added {len(buffers)} buffers") - _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) @@ -465,7 +471,7 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer return state_dict -def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): """ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example @@ -474,6 +480,7 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): Args: - ``checkpoint_dir``: path to the desired checkpoint folder - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters Returns: - pytorch ``state_dict`` @@ -511,10 +518,10 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): if not os.path.isdir(ds_checkpoint_dir): raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") - return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) -def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False): """ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. @@ -523,9 +530,10 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag= - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters """ - state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) print(f"Saving fp32 state dict to {output_file}") torch.save(state_dict, output_file) @@ -584,9 +592,13 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): type=str, default=None, help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") parser.add_argument("-d", "--debug", action='store_true', help="enable debug") args = parser.parse_args() debug = args.debug - convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag) + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_file, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters) From 0a10bd427e035cbd185c2d44346996e8c1a0b42d Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Fri, 9 Feb 2024 09:20:22 -0800 Subject: [PATCH 10/27] Fix alignment of optimizer states when loading (#5105) The ZeRO 1/2 optimizer pads optimizer states according to NCCL's alignment. However, it does not account for NCCL's alignment when loading from an elastic checkpoint, resulting in improperly restored optimizer states. The existing test case only verifies parameter groups and fails to catch this specific issue. This PR addresses the misalignment and enhances the unit test to ensure that optimizer state tensors are correctly matched post-restoration. --- deepspeed/runtime/zero/stage_1_and_2.py | 2 +- tests/unit/checkpoint/common.py | 13 +++++++++++++ tests/unit/checkpoint/test_zero_optimizer.py | 8 ++++---- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 48044c0161a2..3e579422b26d 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -2199,7 +2199,7 @@ def refresh_fp32_params(self): # Extract optimizer state for current partition from merged states of all partitions def _partition_base_optimizer_state(self, state_key, all_partition_states, group_id): partition_id = dist.get_rank(group=self.real_dp_process_group[group_id]) - alignment = dist.get_world_size(group=self.real_dp_process_group[group_id]) + alignment = self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[group_id]) if torch.is_tensor(all_partition_states[0]): flat_merged_partitions = self.flatten_dense_tensors_aligned(all_partition_states, alignment) dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, group_id) diff --git a/tests/unit/checkpoint/common.py b/tests/unit/checkpoint/common.py index d6dda2f14cbe..7442e51bad5d 100644 --- a/tests/unit/checkpoint/common.py +++ b/tests/unit/checkpoint/common.py @@ -96,6 +96,19 @@ def compare_state_dicts(state0, state1, expected_mismatch_keys=[]): assert s0 == s1, f'failures with keys = {k0}, {k1}, values = {type(s0[0])} and {type(s1[0])}' +def compare_opt_state_dicts(state0, state1, expected_mismatch_keys=[]): + for param_group0, saved_param_group1 in zip(state0['param_groups'], state1['param_groups']): + compare_state_dicts(param_group0, saved_param_group1, expected_mismatch_keys) + + assert "state" in state0 + assert "state" in state1 + assert len([state0["state"].keys()]) == len([state1["state"].keys()]) + + for (k0, s0), (k1, s1) in zip(state0["state"].items(), state1["state"].items()): + assert k0 == k1, f'failure due to key mismatch {k0} != {k1}' + compare_state_dicts(s0, s1, expected_mismatch_keys) + + def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True): saved_optimizer = saved_model.optimizer.optimizer if fp16 else saved_model.optimizer loaded_optimizer = loaded_model.optimizer.optimizer if fp16 else loaded_model.optimizer diff --git a/tests/unit/checkpoint/test_zero_optimizer.py b/tests/unit/checkpoint/test_zero_optimizer.py index f2237341ef68..0b9efb3ec462 100644 --- a/tests/unit/checkpoint/test_zero_optimizer.py +++ b/tests/unit/checkpoint/test_zero_optimizer.py @@ -246,7 +246,8 @@ def test_elastic_checkpoint_fixed_dp(self, tmpdir, elastic_save, elastic_load, l model.backward(loss) model.step() if load_optim: - torch.save(model.optimizer.optimizer.state_dict(), os.path.join(tmpdir, 'opt-state-dict')) + opt_state_dict_file = f'opt-state-dict_rank{dist.get_rank()}' + torch.save(model.optimizer.optimizer.state_dict(), os.path.join(tmpdir, opt_state_dict_file)) model.save_checkpoint(tmpdir) ds_config["zero_optimization"]["elastic_checkpoint"] = elastic_load @@ -256,10 +257,9 @@ def test_elastic_checkpoint_fixed_dp(self, tmpdir, elastic_save, elastic_load, l model.load_checkpoint(tmpdir, load_optimizer_states=load_optim) if load_optim: - saved_sd = torch.load(os.path.join(tmpdir, 'opt-state-dict')) + saved_sd = torch.load(os.path.join(tmpdir, opt_state_dict_file)) curr_sd = model.optimizer.optimizer.state_dict() - for curr_param_group, saved_param_group in zip(curr_sd['param_groups'], saved_sd['param_groups']): - compare_state_dicts(curr_param_group, saved_param_group, expected_mismatch_keys) + compare_opt_state_dicts(curr_sd, saved_sd, expected_mismatch_keys) data_loader = random_dataloader(model=model, total_samples=8, hidden_dim=hidden_dim, device=model.device) for n, batch in enumerate(data_loader): From d04a8386d1823326ff2e01442823310a9f9f0c5b Mon Sep 17 00:00:00 2001 From: Lev Kurilenko <113481193+lekurile@users.noreply.github.com> Date: Fri, 9 Feb 2024 14:42:46 -0800 Subject: [PATCH 11/27] Skip Triton import for AMD (#5110) When testing DeepSpeed inference on an `AMD Instinct MI250X/MI250` GPU, the `pytorch-triton-rocm` module would break the `torch.cuda` device API. To address this, importing `triton` is skipped when the GPU is determined to be `AMD`. This change allows DeepSpeed to be executed on an AMD GPU w/o kernel injection in the DeepSpeedExamples [text-generation example](https://github.com/microsoft/DeepSpeedExamples/tree/master/inference/huggingface/text-generation) using the following command: ```bash deepspeed --num_gpus 1 inference-test.py --model facebook/opt-125m ``` TODO: Root-cause the interaction between `pytorch-triton-rocm` and DeepSpeed to understand why this is causing the `torch.cuda` device API to break. --- deepspeed/__init__.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index a6feae06974b..6c7aa8b15ef9 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -12,10 +12,14 @@ from torch.optim.lr_scheduler import _LRScheduler from packaging import version as pkg_version -try: - import triton # noqa: F401 # type: ignore - HAS_TRITON = True -except ImportError: +# Skip Triton import for AMD due to pytorch-triton-rocm module breaking device API in DeepSpeed +if not (hasattr(torch.version, 'hip') and torch.version.hip is not None): + try: + import triton # noqa: F401 # type: ignore + HAS_TRITON = True + except ImportError: + HAS_TRITON = False +else: HAS_TRITON = False from . import ops From 25a02047ae1726f2c67fba75d262e1e1afdfe306 Mon Sep 17 00:00:00 2001 From: Lev Kurilenko <113481193+lekurile@users.noreply.github.com> Date: Fri, 9 Feb 2024 14:43:54 -0800 Subject: [PATCH 12/27] Add HIP conversion file outputs to .gitignore (#5111) This PR adds the following HIP output files to `.gitignore`: ``` *_hip.cpp *_hip.h *.hip *.cuh *hip_layers.h ``` --------- Co-authored-by: Michael Wyatt --- .gitignore | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.gitignore b/.gitignore index e284c4fd35a1..db6790886cb4 100644 --- a/.gitignore +++ b/.gitignore @@ -53,3 +53,10 @@ docs/code-docs/build ## Testing data # Saved checkpoints for testing tests/unit/saved_checkpoint/ + +# HIP files created during AMD compilation +*_hip.cpp +*_hip.h +*.hip +*.cuh +*hip_layers.h From 18179807f5fa5dbb93b7af51d51f9002e1539478 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Sat, 10 Feb 2024 17:47:45 -0800 Subject: [PATCH 13/27] Remove optimizer step on initialization (#5104) All ZeRO 1/2/3 stages call the optimizer's `step()` on its initialization. This increments a counter in the optimizer and produces a different result in parameter update with the normal usage of PyTorch. This PR eliminates `step()` in the initialization and lazily configures some internal states (linking *hp_params*) after the first `step()` call. --------- Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/bf16_optimizer.py | 16 +++- deepspeed/runtime/zero/stage3.py | 4 - deepspeed/runtime/zero/stage_1_and_2.py | 70 +++++++++++----- deepspeed/utils/__init__.py | 2 +- deepspeed/utils/mixed_precision_linkage.py | 10 ++- deepspeed/utils/tensor_fragment.py | 17 ++-- tests/unit/runtime/zero/test_zero.py | 36 ++++++--- .../runtime/zero/test_zero_tensor_fragment.py | 81 +++++++++---------- 8 files changed, 141 insertions(+), 95 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 0aefd1946c36..6c1ae345ebb6 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -18,7 +18,7 @@ align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank, is_model_parallel_parameter, see_memory_usage, graph_process) -from deepspeed.utils import link_hp_params, fragment_address +from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address from deepspeed.checkpoint import enable_universal_checkpoint from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE, SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS, @@ -165,6 +165,7 @@ def _setup_for_real_optimizer(self): # Need optimizer states initialized before linking lp to optimizer state self._link_all_hp_params() + self._hp_optimizer_states_linked = False self._enable_universal_checkpoint() self._param_slice_mappings = self._create_param_mapping() @@ -199,9 +200,15 @@ def _link_all_hp_params(self): param_group_index=i, partition_start=partition_id * partition_size, partition_size=partition_size, - partition_optimizer_state=self.optimizer.state[flat_hp_partition], dp_group=self.real_dp_process_group[i]) + def _lazy_init_hp_params_optimizer_state(self): + if not self._hp_optimizer_states_linked: + for i, _ in enumerate(self.optimizer.param_groups): + lazy_init_hp_params_optimizer_state(self.bf16_groups[i], self.fp32_groups_flat_partition[i], + self.optimizer.state) + self._hp_optimizer_states_linked = True + def initialize_optimizer_states(self): """Take an optimizer step with zero-valued gradients to allocate internal optimizer state. @@ -215,8 +222,6 @@ def initialize_optimizer_states(self): param_partition.grad = grad_partition.to( param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition - self.optimizer.step() - if self.grad_acc_dtype is not torch.float32: for param_partition in self.fp32_groups_flat_partition: param_partition.grad = None @@ -263,6 +268,9 @@ def step(self, closure=None): self.optimizer.step() + # We need to link optimizer state after the first step() call + self._lazy_init_hp_params_optimizer_state() + self.update_lp_params() self.clear_hp_grads() diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index d971092ebd17..42008236a9ea 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1015,10 +1015,6 @@ def initialize_optimizer_states(self): else: self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements) - # Initialize the optimizer states with the flattened fp32 partition. - if not is_adagrad: - self._optimizer_step(i) - if swappable_param_subgroup: self._partitioned_params_swap_out(i) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 3e579422b26d..18b58403f1d7 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -28,7 +28,7 @@ from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER, SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE, BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS) -from deepspeed.utils import link_hp_params +from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state from deepspeed.checkpoint import enable_universal_checkpoint from deepspeed.utils import groups @@ -88,6 +88,12 @@ def _get_padded_tensor(src_tensor, size): return padded_tensor +def _pad_tensor_by_size(src_tensor, pad_size, dtype, device): + padded_tensor = torch.zeros(src_tensor.numel() + pad_size, dtype=dtype, device=device) + padded_tensor.data[:src_tensor.numel()].copy_(src_tensor.data) + return padded_tensor + + class DeepSpeedZeroOptimizer(ZeROOptimizer): """ DeepSpeedZeroOptimizer designed to reduce the memory footprint @@ -536,6 +542,8 @@ def __init__(self, see_memory_usage(f"After initializing ZeRO optimizer", force=True) self._link_all_hp_params() + self._hp_optimizer_states_linked = False + self._enable_universal_checkpoint() self._param_slice_mappings = self._create_param_mapping() @@ -578,9 +586,15 @@ def _link_all_hp_params(self): param_group_index=i, partition_start=partition_id * partition_size, partition_size=partition_size, - partition_optimizer_state=self.optimizer.state[flat_hp_partition], dp_group=self.real_dp_process_group[i]) + def _lazy_init_hp_params_optimizer_state(self): + if not self._hp_optimizer_states_linked: + for i, _ in enumerate(self.optimizer.param_groups): + lazy_init_hp_params_optimizer_state(self.bit16_groups[i], self.single_partition_of_fp32_groups[i], + self.optimizer.state) + self._hp_optimizer_states_linked = True + def is_moe_group(self, group): return 'moe' in group and group['moe'] @@ -664,8 +678,6 @@ def initialize_optimizer_states(self): # which do lazy initialization of the state at the first call to step. if isinstance(self.optimizer, torch.optim.Adagrad): self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults) - else: - self.optimizer.step() if not self.cpu_offload: for group in self.single_partition_of_fp32_groups: @@ -1793,6 +1805,9 @@ def _optimizer_step(self, group_no): self.optimizer.step() self.optimizer.param_groups = original_param_groups + # We need to link optimizer state after the first step() call + self._lazy_init_hp_params_optimizer_state() + def step(self, closure=None): """ Not supporting closure. @@ -2208,19 +2223,39 @@ def _partition_base_optimizer_state(self, state_key, all_partition_states, group # Assume non-tensor states are not partitioned and equal across ranks, so return first one return all_partition_states[0] - def _restore_base_optimizer_state(self, base_optimizer_group_states): + def _restore_step_from_elastic_checkpoint(self, all_state_dict): + assert BASE_OPTIMIZER_STATE_STEP in all_state_dict[0] + assert all(sd[BASE_OPTIMIZER_STATE_STEP] == all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] + for sd in all_state_dict), "State dicts of all partitions must have the same step value" + return all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] + + def _restore_base_optimizer_state(self, base_optimizer_group_states, base_optimizer_state_step, group_paddings): if type(base_optimizer_group_states) == dict: base_optimizer_group_states = base_optimizer_group_states['state'] + + saved_keys = base_optimizer_group_states[0].keys() + for i, group in enumerate(self.optimizer.param_groups): p = group['params'][0] - for key, saved in base_optimizer_group_states[i].items(): - if torch.is_tensor(self.optimizer.state[p][key]): - dst_tensor = self.optimizer.state[p][key] - src_tensor = _get_padded_tensor(saved, dst_tensor.numel()) - self.optimizer.state[p][key].data.copy_(src_tensor.data) + padding = 0 if group_paddings is None else group_paddings[i] + for key in saved_keys: + saved = base_optimizer_group_states[i][key] + + if torch.is_tensor(saved): + if key in self.optimizer.state[p]: + dst_tensor = self.optimizer.state[p][key] + src_tensor = _get_padded_tensor(saved, dst_tensor.numel()) + self.optimizer.state[p][key].data.copy_(src_tensor.data) + else: + self.optimizer.state[p][key] = _pad_tensor_by_size( + saved, padding, torch.float32, + torch.device('cpu') if self.cpu_offload else self.device) else: self.optimizer.state[p][key] = saved + for param_group in self.optimizer.param_groups: + param_group['step'] = base_optimizer_state_step + def get_ep_ranks(self, rank=0, group_name=None): from deepspeed.utils import groups expert_parallel_size_ = groups._get_expert_parallel_world_size(group_name) @@ -2248,15 +2283,8 @@ def _restore_elastic_base_optimizer_state(self, all_state_dict): partition_states[key] = self._partition_base_optimizer_state(key, all_partition_states, i) base_optimizer_group_states.append(partition_states) - self._restore_base_optimizer_state(base_optimizer_group_states) - - # Restore step - if BASE_OPTIMIZER_STATE_STEP in all_state_dict[0]: - assert all(sd[BASE_OPTIMIZER_STATE_STEP] == all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] - for sd in all_state_dict), "State dicts of all partitions must have the same step value" - loaded_param_groups_step = all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] - for param_group in self.optimizer.param_groups: - param_group['step'] = loaded_param_groups_step + self._restore_base_optimizer_state(base_optimizer_group_states, + self._restore_step_from_elastic_checkpoint(all_state_dict), None) def load_state_dict(self, state_dict_list, @@ -2368,7 +2396,9 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l self._restore_elastic_base_optimizer_state(state_dict_list) else: # loading an elastic checkpoint into rigid exec - self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE]) + self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE], + current_rank_sd[BASE_OPTIMIZER_STATE_STEP], + current_rank_sd[GROUP_PADDINGS]) # At this point, the optimizer's references to the model's fp32 parameters are up to date. # The optimizer's hyperparameters and internal buffers are also up to date. diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py index 1f86306aefec..33ea8ba60818 100644 --- a/deepspeed/utils/__init__.py +++ b/deepspeed/utils/__init__.py @@ -17,6 +17,6 @@ from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_optimizer_state from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter -from .mixed_precision_linkage import link_hp_params +from .mixed_precision_linkage import link_hp_params, lazy_init_hp_params_optimizer_state from deepspeed.runtime.dataloader import RepeatingLoader from .numa import get_numactl_cmd diff --git a/deepspeed/utils/mixed_precision_linkage.py b/deepspeed/utils/mixed_precision_linkage.py index b1afa8f00aa3..7dea6ba322db 100644 --- a/deepspeed/utils/mixed_precision_linkage.py +++ b/deepspeed/utils/mixed_precision_linkage.py @@ -9,13 +9,19 @@ def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload, - param_group_index, partition_start, partition_size, partition_optimizer_state, dp_group): + param_group_index, partition_start, partition_size, dp_group): local_lp_param_and_offset = _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group) for lp_param, lp_start in local_lp_param_and_offset: lp_param._hp_mapping = get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload, param_group_index, - partition_start, partition_size, partition_optimizer_state) + partition_start, partition_size) + + +def lazy_init_hp_params_optimizer_state(lp_param_list, flat_hp_partition, optimizer_state): + for lp in lp_param_list: + if lp._hp_mapping is not None: + lp._hp_mapping.set_optim_state_fragment(flat_hp_partition, optimizer_state[flat_hp_partition]) def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group): diff --git a/deepspeed/utils/tensor_fragment.py b/deepspeed/utils/tensor_fragment.py index 5f94070dc4c7..49eefafcfbcc 100644 --- a/deepspeed/utils/tensor_fragment.py +++ b/deepspeed/utils/tensor_fragment.py @@ -21,11 +21,11 @@ class tensor_fragment: lp_fragment_address: fragment_address hp_fragment: torch.Tensor hp_fragment_address: fragment_address - optim_fragment: Dict gradient_dict: Dict offload_gradient_dict: Dict use_offload: bool param_group_index: int + optim_fragment: Dict = None def update_hp(self): self.hp_fragment.data.copy_(self.lp_fragment.data) @@ -39,6 +39,13 @@ def get_optim_state_fragment(self, key): else: raise ValueError(f'{key} not found in optimizer state fragment') + def set_optim_state_fragment(self, flat_hp_partition, optim_fragment): + self.optim_fragment = { + key: value.narrow(0, self.hp_fragment_address.start, self.hp_fragment_address.numel) + for key, value in optim_fragment.items() + if torch.is_tensor(value) and value.shape == flat_hp_partition.shape + } + def get_hp_fragment_address(self): return self.hp_fragment_address @@ -255,7 +262,7 @@ def safe_set_local_fp32_param(param, value): def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload, - param_group_index, partition_start, partition_size, optimizer_state_dict): + param_group_index, partition_start, partition_size): lp_end = lp_param.numel() + lp_start hp_start = partition_start hp_end = partition_start + partition_size @@ -268,11 +275,6 @@ def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict fragment_numel = fragment_end - fragment_start hp_frag_address = fragment_address(start=fragment_start - hp_start, numel=fragment_numel) hp_fragment_tensor = flat_hp_partition.narrow(0, hp_frag_address.start, hp_frag_address.numel) - optim_fragment = { - key: value.narrow(0, hp_frag_address.start, hp_frag_address.numel) - for key, value in optimizer_state_dict.items() - if torch.is_tensor(value) and value.shape == flat_hp_partition.shape - } lp_frag_address = fragment_address(start=fragment_start - lp_start, numel=fragment_numel) lp_fragment_tensor = lp_param.flatten().narrow(0, lp_frag_address.start, lp_frag_address.numel) @@ -281,7 +283,6 @@ def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict lp_fragment_address=lp_frag_address, hp_fragment=hp_fragment_tensor, hp_fragment_address=hp_frag_address, - optim_fragment=optim_fragment, gradient_dict=gradient_dict, offload_gradient_dict=offload_gradient_dict, use_offload=use_offload, diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py index bc31e3b9a968..2594d910acff 100644 --- a/tests/unit/runtime/zero/test_zero.py +++ b/tests/unit/runtime/zero/test_zero.py @@ -1370,6 +1370,11 @@ class TestZeroAdamOptimizerStepCount(DistributedTest): world_size = 1 def test(self, zero_stage): + # We verify trhee conditions: + # 1. global_steps starts at 0 + # 2. All subgroups have the same step count + # 3. The global step count is the same as the step count of the first subgroup + # force all params to be partitioned by forcing threshold=0 config_dict = { "train_micro_batch_size_per_gpu": 2, @@ -1399,24 +1404,31 @@ def test(self, zero_stage): model_parameters=model.parameters()) data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device) - for i, batch in enumerate(data_loader): + assert model.global_steps == 0 + + for batch in data_loader: loss = model(batch[0], batch[1]) model.backward(loss) + + is_gradient_accumulation_boundary = model.is_gradient_accumulation_boundary() model.step() - step_counts = [] - if zero_stage == 3: - for sub_group_id, _ in enumerate(optimizer.fp16_groups): - fp32_param = optimizer.fp32_partitioned_groups_flat[sub_group_id] - state = optimizer.optimizer.state[fp32_param] - step_counts.append(state["step"]) - assert all(step == step_counts[0] for step in step_counts) - elif zero_stage == 1 or zero_stage == 2: - for param_group in optimizer.optimizer.param_groups: - for param in param_group["params"]: - state = optimizer.optimizer.state[param] + if is_gradient_accumulation_boundary: + step_counts = [] + + if zero_stage == 3: + for sub_group_id, _ in enumerate(optimizer.fp16_groups): + fp32_param = optimizer.fp32_partitioned_groups_flat[sub_group_id] + state = optimizer.optimizer.state[fp32_param] step_counts.append(state["step"]) + elif zero_stage == 1 or zero_stage == 2: + for param_group in optimizer.optimizer.param_groups: + for param in param_group["params"]: + state = optimizer.optimizer.state[param] + step_counts.append(state["step"]) + assert all(step == step_counts[0] for step in step_counts) + assert model.global_steps == step_counts[0] @pytest.mark.parametrize("zero_stage", [1, 2, 3]) diff --git a/tests/unit/runtime/zero/test_zero_tensor_fragment.py b/tests/unit/runtime/zero/test_zero_tensor_fragment.py index c223e67af697..b3adfdf96c50 100644 --- a/tests/unit/runtime/zero/test_zero_tensor_fragment.py +++ b/tests/unit/runtime/zero/test_zero_tensor_fragment.py @@ -24,35 +24,26 @@ SECOND_ORDER_KEY = 'exp_avg_sq' -def validate_full_tensors(model): +def validate_tensor(model, api_type, opt_states): + assert api_type in ["full", "local"] for _, lp in model.named_parameters(): - hp = safe_get_full_fp32_param(lp) - exp_avg = safe_get_full_optimizer_state(lp, 'exp_avg') - exp_avg_sq = safe_get_full_optimizer_state(lp, 'exp_avg_sq') - hp_grad = safe_get_full_grad(lp) - param_list = [hp, hp_grad, exp_avg, exp_avg_sq] - if lp.requires_grad: - assert all([p is not None for p in param_list]) + param_list = [] + if opt_states: + param_list.append( + safe_get_full_optimizer_state(lp, 'exp_avg') if api_type == + "full" else safe_get_local_optimizer_state(lp, 'exp_avg')) + param_list.append( + safe_get_full_optimizer_state(lp, 'exp_avg_sq') if api_type == + "full" else safe_get_local_optimizer_state(lp, 'exp_avg_sq')) else: - assert all([p is None for p in param_list]) - - -def validate_local_tensors(model): - for _, lp in model.named_parameters(): - hp = safe_get_local_fp32_param(lp) - exp_avg = safe_get_local_optimizer_state(lp, 'exp_avg') - exp_avg_sq = safe_get_local_optimizer_state(lp, 'exp_avg_sq') - hp_grad = safe_get_local_grad(lp) - param_list = [hp, hp_grad, exp_avg, exp_avg_sq] + param_list.append(safe_get_full_fp32_param(lp) if api_type == "full" else safe_get_local_fp32_param(lp)) + param_list.append(safe_get_full_grad(lp) if api_type == "full" else safe_get_local_grad(lp)) if lp.requires_grad: assert all([p is not None for p in param_list]) else: assert all([p is None for p in param_list]) -validate_funcs_mapping = {"full": validate_full_tensors, "local": validate_local_tensors} - - class MyModel(torch.nn.Module): def __init__(self, hidden_dim, frozen_weights): @@ -71,12 +62,10 @@ def forward(self, x, y): for l in self.linears: x = l(x) x = self.act(x) - loss = self.cel(x, y) - val = (x, loss) - return val + return self.cel(x, y) -def run_fragmented_model(model, config_dict, hidden_dim, dtype, validate_func): +def run_fragmented_model(model, config_dict, hidden_dim, dtype, validate_after_bwd, validate_after_step): model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) data_loader = random_dataloader(model=model, total_samples=10, @@ -86,10 +75,10 @@ def run_fragmented_model(model, config_dict, hidden_dim, dtype, validate_func): dist.barrier() for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) - loss = loss[1] model.backward(loss) - validate_func(model) + validate_after_bwd(model) model.step() + validate_after_step(model) # Needed in ZeRO 3. Not doing so can give memory leak model.destroy() @@ -147,9 +136,10 @@ def test_zero_fragments(self, tmpdir, api_type, zero_stage, offload_device, froz else: model = MyModel(hidden_dim, frozen_weights) - validate_func = validate_funcs_mapping[api_type] + validate_after_bwd = lambda model: validate_tensor(model, api_type, opt_states=False) + validate_after_step = lambda model: validate_tensor(model, api_type, opt_states=True) - run_fragmented_model(model, config_dict, hidden_dim, torch.float16, validate_func) + run_fragmented_model(model, config_dict, hidden_dim, torch.float16, validate_after_bwd, validate_after_step) def test_bf16_fragments(self, frozen_weights): if frozen_weights: @@ -178,7 +168,12 @@ def test_bf16_fragments(self, frozen_weights): hidden_dim = 128 model = MyModel(hidden_dim, frozen_weights) - run_fragmented_model(model, config_dict, hidden_dim, torch.bfloat16, validate_full_tensors) + + api_type = "full" + validate_after_bwd = lambda model: validate_tensor(model, api_type, opt_states=False) + validate_after_step = lambda model: validate_tensor(model, api_type, opt_states=True) + + run_fragmented_model(model, config_dict, hidden_dim, torch.bfloat16, validate_after_bwd, validate_after_step) def create_random_values(model, key_list, group, use_cuda=True): @@ -315,23 +310,21 @@ def test_zero_fragments(self, tmpdir, api_type, zero_stage, offload_device, dtyp if zero_stage == 3: config_dict["zero_optimization"]["param_persistence_threshold"] = hidden_dim with deepspeed.zero.Init(config_dict_or_path=config_dict): - model = SimpleModel(hidden_dim, nlayers=4) + model = SimpleModel(hidden_dim) else: - model = SimpleModel(hidden_dim, nlayers=4) + model = SimpleModel(hidden_dim) - model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) world = dist.get_world_size() group = dist.new_group(ranks=list(range(world))) dist.barrier() - optim_keys = [WEIGHT_KEY, FIRST_ORDER_KEY, SECOND_ORDER_KEY] - helper_funcs = helper_funcs_mapping[api_type] - optim_state_values = helper_funcs["create_random_values"](model, - optim_keys, - group, - use_cuda=offload_device == OffloadDeviceEnum.none) - helper_funcs["set_param_values_with_dict"](model, optim_state_values) - helper_funcs["validate_param_values_with_dict"](model, optim_state_values) - - # Needed in ZeRO 3. Not doing so can leak memory. - model.destroy() + + def validate_func(model): + optim_keys = [WEIGHT_KEY, FIRST_ORDER_KEY, SECOND_ORDER_KEY] + helper_funcs = helper_funcs_mapping[api_type] + optim_state_values = helper_funcs["create_random_values"]( + model, optim_keys, group, use_cuda=offload_device == OffloadDeviceEnum.none) + helper_funcs["set_param_values_with_dict"](model, optim_state_values) + helper_funcs["validate_param_values_with_dict"](model, optim_state_values) + + run_fragmented_model(model, config_dict, hidden_dim, dtype, lambda _: None, validate_func) From 41bc9feed00d5893803a66d1657e684fde67b795 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Mon, 12 Feb 2024 09:29:01 -0800 Subject: [PATCH 14/27] Update version.txt after 0.13.2 release (#5119) **Auto-generated PR to update version.txt after a DeepSpeed release** Released version - 0.13.2 Author - @mrwyattii Co-authored-by: mrwyattii --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index 9beb74d490bc..288adf538f0e 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.13.2 +0.13.3 From f295aea09eb38aace6ac2155a452bbba7df57aa3 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Mon, 12 Feb 2024 12:26:59 -0800 Subject: [PATCH 15/27] Stop tracking backward chain of broadcast (ZeRO3) (#5113) The broadcast that ZeRO3 uses on initialization displays a warning shown below. This PR avoids this by passing `.data` to the broadcast. The same issue of ZeRO 1/2 was addressed in #5075 using `torch.no_grad`, which affects multiple lines in the scope. This PR also changes the fix for ZeRO 1/2 to passing `.data` to broadcast for consistency and safety. ``` /home/mtanaka/.conda/envs/tcomp/lib/python3.9/site-packages/torch/autograd/graph.py:681: UserWarning: c10d::broadcast_: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /opt/conda/conda-bld/pytorch_1704786093577/work/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:63.) ``` --- deepspeed/runtime/engine.py | 21 +++++++++---------- .../runtime/zero/partition_parameters.py | 7 ++++--- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index f0602813f3ab..e1bf0e167665 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1065,17 +1065,16 @@ def is_replicated(p): return False return True - with torch.no_grad(): - for p in self.module.parameters(): - # Broadcast the model for different parameters - if is_moe_param(p): - if torch.is_tensor(p) and is_replicated(p): - dist.broadcast(p, - groups._get_expert_broadcast_src_rank(p.group_name), - group=self.expert_data_parallel_group[p.group_name]) - else: - if torch.is_tensor(p) and is_replicated(p): - dist.broadcast(p, groups._get_broadcast_src_rank(), group=self.seq_data_parallel_group) + for p in self.module.parameters(): + # Broadcast the model for different parameters + if is_moe_param(p): + if torch.is_tensor(p) and is_replicated(p): + dist.broadcast(p.data, + groups._get_expert_broadcast_src_rank(p.group_name), + group=self.expert_data_parallel_group[p.group_name]) + else: + if torch.is_tensor(p) and is_replicated(p): + dist.broadcast(p.data, groups._get_broadcast_src_rank(), group=self.seq_data_parallel_group) @staticmethod def __check_params(model: Module, dtype: torch.dtype) -> None: diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 99a9d100082b..5cf655d8741a 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -1008,9 +1008,10 @@ def _update_persist_config(self, ds_config): def _zero_init_param(self, param): self._convert_to_deepspeed_param(param) if dist.get_world_group() == self.get_dp_process_group(): - dist.broadcast(param, 0, self.get_dp_process_group()) + dist.broadcast(param.data, 0, self.get_dp_process_group()) else: - dist.broadcast(param, dist.get_global_rank(self.get_dp_process_group(), 0), self.get_dp_process_group()) + dist.broadcast(param.data, dist.get_global_rank(self.get_dp_process_group(), 0), + self.get_dp_process_group()) param.partition() def _convert_to_zero_parameters(self, param_list): @@ -2177,7 +2178,7 @@ def __exit__(self, *exc): self.params[0].partition(param_list=self.params, has_been_updated=False) return - handles = [dist.broadcast(p, self.src_rank, group=p.ds_process_group, async_op=True) for p in self.params] + handles = [dist.broadcast(p.data, self.src_rank, group=p.ds_process_group, async_op=True) for p in self.params] for h in handles: h.wait() self.params[0].partition(param_list=self.params, has_been_updated=True) From 0d9cfa01e2947df25631aceac22af56bb361e9b3 Mon Sep 17 00:00:00 2001 From: Mingzhen WANG Date: Tue, 13 Feb 2024 04:28:20 +0800 Subject: [PATCH 16/27] [NPU]ZeRO-Infinity feature compatibility (#5077) NPU hardware platform does not support `O_DIRECT`. Surrounded by `#if defined`, it will have no impact on other hardware platforms such as GPUs --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase --- csrc/aio/common/deepspeed_aio_common.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/aio/common/deepspeed_aio_common.cpp b/csrc/aio/common/deepspeed_aio_common.cpp index 32b0e8a32394..0f2895dfa328 100644 --- a/csrc/aio/common/deepspeed_aio_common.cpp +++ b/csrc/aio/common/deepspeed_aio_common.cpp @@ -268,6 +268,10 @@ void report_file_error(const char* filename, const std::string file_op, const in int open_file(const char* filename, const bool read_op) { const int flags = read_op ? (O_RDONLY | O_DIRECT) : (O_WRONLY | O_CREAT | O_DIRECT); +#if defined(__ENABLE_CANN__) + int* flags_ptr = (int*)&flags; + *flags_ptr = read_op ? (O_RDONLY) : (O_WRONLY | O_CREAT); +#endif const int mode = 0600; const auto fd = open(filename, flags, mode); if (fd == -1) { From 3c0bd3128866e5313e8745ed4a49f19ebd130bbf Mon Sep 17 00:00:00 2001 From: Max Kovalenko <75629718+deepcharm@users.noreply.github.com> Date: Mon, 12 Feb 2024 22:28:49 +0200 Subject: [PATCH 17/27] BF16 optimizer: Improve device utilization by immediate grad update (#4975) Enabled gradient accumulation in bf16 optimizer which updates fp32 gradients once they are available. This improves device utilization on some back-ends, by parallelizing the workload across engines. To enable the feature (disabled by default), use a new config flag "immediate_grad_update" under "bf16" section in Deepspeed config.json (default is false). Example: "bf16": { "enabled": true, "immediate_grad_update": true } --------- Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/bf16_optimizer.py | 73 ++++++++++++++++++++++------- deepspeed/runtime/config.py | 9 ++++ deepspeed/runtime/constants.py | 4 ++ deepspeed/runtime/engine.py | 3 +- 4 files changed, 70 insertions(+), 19 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 6c1ae345ebb6..aaa836bf1c31 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -39,7 +39,8 @@ def __init__(self, dp_process_group=None, timers=None, grad_acc_dtype=None, - graph_harvesting=False): + graph_harvesting=False, + immediate_grad_update=False): super().__init__() see_memory_usage('begin bf16_optimizer', force=True) self.timers = timers @@ -50,6 +51,7 @@ def __init__(self, assert grad_acc_dtype in [torch.float32, torch.bfloat16 ], f"BF16Optimizer: Unsupported gradient accumulation data type: {grad_acc_dtype}" self.grad_acc_dtype = grad_acc_dtype + self.immediate_grad_update = immediate_grad_update self.clip_grad = clip_grad self.norm_type = norm_type @@ -163,6 +165,9 @@ def _setup_for_real_optimizer(self): self.initialize_optimizer_states() see_memory_usage('end initialize_optimizer', force=True) + if self.immediate_grad_update: + self.create_grad_acc_hooks() + # Need optimizer states initialized before linking lp to optimizer state self._link_all_hp_params() self._hp_optimizer_states_linked = False @@ -291,27 +296,37 @@ def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg self.update_hp_grads(clear_lp_grads=clear_lp_grads) @torch.no_grad() - def update_hp_grads(self, clear_lp_grads=False): + def _update_hp_grad(self, lp, group_idx, param_idx, clear_lp_grads): + if lp.grad is None: + return - def _update_hp_grads_func(clear_lp_grads=False): - for i, group in enumerate(self.bf16_groups): - for j, lp in enumerate(group): - if lp.grad is None: - continue - hp_grad = self.fp32_groups_gradients[i][j] - assert hp_grad is not None, \ - f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]' - hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape)) - lp._hp_grad = hp_grad - self.fp32_groups_has_gradients[i][j] = True - # clear gradients - if clear_lp_grads: - lp.grad._zero() + hp_grad = self.fp32_groups_gradients[group_idx][param_idx] + assert hp_grad is not None, \ + f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{group_idx}][{param_idx}]' + + hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape)) + lp._hp_grad = hp_grad + self.fp32_groups_has_gradients[group_idx][param_idx] = True + + # clear gradients + if clear_lp_grads: + lp.grad._zero() + + @torch.no_grad() + def _update_hp_grads_func(self, clear_lp_grads=False): + for i, group in enumerate(self.bf16_groups): + for j, lp in enumerate(group): + self._update_hp_grad(lp, i, j, clear_lp_grads) + + @torch.no_grad() + def update_hp_grads(self, clear_lp_grads=False): + if self.immediate_grad_update: + return if self.graph_harvesting: - graph_process(False, _update_hp_grads_func, clear_lp_grads) + graph_process(False, self._update_hp_grads_func, clear_lp_grads) else: - _update_hp_grads_func(clear_lp_grads) + self._update_hp_grads_func(clear_lp_grads) #cpu op for i, group in enumerate(self.bf16_groups): for j, lp in enumerate(group): @@ -449,6 +464,28 @@ def _load_hp_checkpoint_state(self, checkpoint_dir): lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank, tp_world_size) + def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx): + assert self.immediate_grad_update + self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=False) + + def create_grad_acc_hooks(self): + self.grad_accs = [] + for i, param_group in enumerate(self.bf16_groups): + for j, param in enumerate(param_group): + if param.requires_grad: + + def wrapper(param, i, j): + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + + def accumulate_hp_grads_and_remove_lp(*notneeded): + self.accumulate_hp_grads_and_remove_lp(param, i, j) + + grad_acc.register_hook(accumulate_hp_grads_and_remove_lp) + self.grad_accs.append(grad_acc) + + wrapper(param, i, j) + def _get_padded_tensor(src_tensor, size): if src_tensor.numel() >= size: diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 20fbf475ca90..975fb1f21501 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -169,6 +169,14 @@ def get_bfloat16_enabled(param_dict): return False +def get_bfloat16_immediate_grad_update(param_dict): + for key in [BFLOAT16, BFLOAT16_OLD]: + if key in param_dict.keys(): + return get_scalar_param(param_dict[key], BFLOAT16_IMMEDIATE_GRAD_UPDATE, + BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT) + return False + + def get_fp16_master_weights_and_grads_enabled(param_dict): if get_fp16_enabled(param_dict): return get_scalar_param(param_dict[FP16], FP16_MASTER_WEIGHTS_AND_GRADS, FP16_MASTER_WEIGHTS_AND_GRADS_DEFAULT) @@ -818,6 +826,7 @@ def _initialize_params(self, param_dict): self.fp16_enabled = get_fp16_enabled(param_dict) self.fp16_auto_cast = get_fp16_auto_cast(param_dict) self.bfloat16_enabled = get_bfloat16_enabled(param_dict) + self.bfloat16_immediate_grad_update = get_bfloat16_immediate_grad_update(param_dict) assert not (self.fp16_enabled and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled' self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(param_dict) diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index 82d8a0557a41..679230ca7d4c 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -126,6 +126,10 @@ BFLOAT16_ENABLED = "enabled" BFLOAT16_ENABLED_DEFAULT = False +# BFLOAT16 optimizer immediate gradient update +BFLOAT16_IMMEDIATE_GRAD_UPDATE = "immediate_grad_update" +BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = False + ######################################### # FP16 support ######################################### diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index e1bf0e167665..fe0fa8ad793d 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1477,7 +1477,8 @@ def _configure_bf16_optimizer(self, optimizer): dp_process_group=self.seq_data_parallel_group, timers=timers, grad_acc_dtype=self.get_data_types()[1], - graph_harvesting=self.graph_harvesting()) + graph_harvesting=self.graph_harvesting(), + immediate_grad_update=self._config.bfloat16_immediate_grad_update) return optimizer From 9aaead552087c539a2eef178c2e6eeb399d47718 Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Mon, 12 Feb 2024 23:13:37 +0100 Subject: [PATCH 18/27] removed if condition in `if collate_fn is None` (#5107) remove the if condition for value of `collate_fn`, because `collate_fn` is `None` by default in `DataLoader` Co-authored-by: Michael Wyatt Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- .../data_pipeline/data_sampling/data_analyzer.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py index cb0d366ce798..d7c6b82590b0 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py @@ -136,15 +136,12 @@ def run_map_helper(self, thread_id): f"on data subset {start_idx} to {end_idx}") thread_dataset = Subset(self.dataset, list(range(start_idx, end_idx))) sampler = BatchSampler(SequentialSampler(thread_dataset), batch_size=self.batch_size, drop_last=False) - if self.collate_fn is None: - iterator = iter(DataLoader(thread_dataset, batch_sampler=sampler, num_workers=0, pin_memory=False)) - else: - iterator = iter( - DataLoader(thread_dataset, - batch_sampler=sampler, - num_workers=0, - collate_fn=self.collate_fn, - pin_memory=False)) + iterator = iter( + DataLoader(thread_dataset, + batch_sampler=sampler, + num_workers=0, + collate_fn=self.collate_fn, + pin_memory=False)) if self.custom_map_init is None: metric_results = self.init_metric_results(thread_id, self.metric_names, self.metric_types, self.metric_dtypes, self.save_path, self.worker_id) From d67d4e5b2f178822689f883f01075b8c916aba84 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Mon, 12 Feb 2024 14:41:45 -0800 Subject: [PATCH 19/27] disable compile tests for torch<2.1 (#5121) Tests running older version of torch will fail the compile tests added in #4878. --- tests/unit/runtime/compile/test_compile_wrapper.py | 4 ++++ tests/unit/runtime/compile/test_compile_zero.py | 4 ++++ tests/unit/runtime/compile/test_load_config.py | 4 ++++ 3 files changed, 12 insertions(+) diff --git a/tests/unit/runtime/compile/test_compile_wrapper.py b/tests/unit/runtime/compile/test_compile_wrapper.py index fbf235fb7d62..98a7c28c6a28 100644 --- a/tests/unit/runtime/compile/test_compile_wrapper.py +++ b/tests/unit/runtime/compile/test_compile_wrapper.py @@ -8,9 +8,13 @@ import deepspeed from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.utils import required_torch_version from unit.common import DistributedTest +pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.1), + reason="Compile tests requires Pytorch version 2.1 or above") + @pytest.fixture def base_config(): diff --git a/tests/unit/runtime/compile/test_compile_zero.py b/tests/unit/runtime/compile/test_compile_zero.py index 87e3c52b9e3c..910f32db1c96 100644 --- a/tests/unit/runtime/compile/test_compile_zero.py +++ b/tests/unit/runtime/compile/test_compile_zero.py @@ -7,11 +7,15 @@ import torch from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum +from deepspeed.runtime.utils import required_torch_version from unit.runtime.compile.util import compare_loss from unit.common import DistributedTest from unit.util import bf16_required_version_check +pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.1), + reason="Compile tests requires Pytorch version 2.1 or above") + class TestZeRO(DistributedTest): world_size = 2 diff --git a/tests/unit/runtime/compile/test_load_config.py b/tests/unit/runtime/compile/test_load_config.py index 351e91d2f69b..5f1c01b86852 100644 --- a/tests/unit/runtime/compile/test_load_config.py +++ b/tests/unit/runtime/compile/test_load_config.py @@ -9,9 +9,13 @@ from unit.simple_model import SimpleModel import deepspeed from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.utils import required_torch_version from unit.common import DistributedTest +pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.1), + reason="Compile tests requires Pytorch version 2.1 or above") + custom_backend_called = False custom_compler_fn_called = False From d532f643654043cdd168db0c57f918f90ac8b805 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 13 Feb 2024 11:45:34 -0800 Subject: [PATCH 20/27] Update inference test model names (#5127) --- tests/unit/inference/test_inference.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 067a4969869f..f3056a225a9b 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -36,16 +36,16 @@ pytest.skip("skip inference tests on rocm for now", allow_module_level=True) _bert_models = [ - "bert-base-cased", - "bert-base-uncased", - "bert-large-cased", - "bert-large-uncased", - "bert-base-multilingual-cased", - "bert-base-multilingual-uncased", + "google-bert/bert-base-cased", + "google-bert/bert-base-uncased", + "google-bert/bert-large-cased", + "google-bert/bert-large-uncased", + "google-bert/bert-base-multilingual-cased", + "google-bert/bert-base-multilingual-uncased", "deepset/minilm-uncased-squad2", "cross-encoder/ms-marco-MiniLM-L-12-v2", "dslim/bert-base-NER", - "bert-large-uncased-whole-word-masking-finetuned-squad", + "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad", "distilbert/distilbert-base-cased-distilled-squad", ] _roberta_models = [ From 9c69662032c381edede22446a3b0bef57febef14 Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Tue, 13 Feb 2024 23:32:57 +0100 Subject: [PATCH 21/27] Fix issue with zero-sized file after merging file on curriculum `map_reduce` (#5106) In `deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py` when calling `merge_file_` , the following operation may not flush the merged file in time, before it's needed: ``` # Concatenate data with open(data_file_path(another_file), 'rb') as f: shutil.copyfileobj(f, self._data_file) ``` this leads to `self._data_file` having size zero, and later to the following error (with stack trace): ``` File "~/my_code/deepspeed_trainer.py", line 999, in my_func data_analyzer.run_reduce() File "~/my_env/lib/python3.11/site-packages/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py", line 413, in run_reduce self.merge_map_results(self.dataset, self.metric_names, self.metric_types, self.save_path, File "~/my_env/lib/python3.11/site-packages/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py", line 371, in merge_map_results index_to_sample = MMapIndexedDataset(index_to_sample_fname, skip_warmup=True) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "~/my_env/lib/python3.11/site-packages/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py", line 486, in __init__ self._do_init(path, skip_warmup) File "~/my_env/lib/python3.11/site-packages/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py", line 502, in _do_init self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C') ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "~/my_env/lib/python3.11/site-packages/numpy/core/memmap.py", line 268, in __new__ mm = mmap.mmap(fid.fileno(), bytes, access=acc, offset=start) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ValueError: cannot mmap an empty file ``` This PR fixes that issue by forcing the destination file to be flushed and adding an assert to make sure the concatenation succeeded. deepspeed version: '0.13.2' --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- .../data_pipeline/data_sampling/data_analyzer.py | 14 ++++++++------ .../data_pipeline/data_sampling/indexed_dataset.py | 2 ++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py index d7c6b82590b0..c9957bb9ea3c 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py @@ -84,12 +84,14 @@ def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtyp metric_results.append({"metric_value": metric_value, "metric_value_fname": metric_value_fname}) return metric_results - def update_metric_results(self, data, metric_types, metric_functions, metric_results): + def update_metric_results(self, data, metric_types, metric_dtypes, metric_functions, metric_results): for m_idx in range(len(metric_types)): - metric_type, metric_function, metric_result = metric_types[m_idx], \ - metric_functions[m_idx], metric_results[m_idx] + metric_type, metric_dtype, metric_function, metric_result = metric_types[m_idx], \ + metric_dtypes[m_idx], metric_functions[m_idx], metric_results[m_idx] + metric_values = metric_function(data) + assert metric_values.numpy().dtype == metric_dtype, \ + f"dtype {metric_values.numpy().dtype} returned by metric_function {metric_function} is not consistent with the metric_dtype {metric_dtype}" if metric_type == 'single_value_per_sample': - metric_values = metric_function(data) for row in range(metric_values.size()[0]): metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1)) metric_result["metric_to_sample_dict"][metric_values[row].item()].append( @@ -102,7 +104,6 @@ def update_metric_results(self, data, metric_types, metric_functions, metric_res writer.writerows([metric_result["metric_to_sample_dict"][m_value]]) metric_result["metric_to_sample_dict"][m_value] = [] elif metric_type == 'accumulate_value_over_samples': - metric_values = metric_function(data) if metric_result["metric_value"] is None: metric_result["metric_value"] = metric_values else: @@ -155,7 +156,8 @@ def run_map_helper(self, thread_id): try: data = next(iterator) if self.custom_map_update is None: - self.update_metric_results(data, self.metric_types, self.metric_functions, metric_results) + self.update_metric_results(data, self.metric_types, self.metric_dtypes, self.metric_functions, + metric_results) else: self.custom_map_update(data, self.metric_types, self.metric_functions, metric_results) processed_sample += self.batch_size diff --git a/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py b/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py index 60115fa6efef..facf85dbe7b4 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py @@ -609,6 +609,8 @@ def merge_file_(self, another_file): # Concatenate data with open(data_file_path(another_file), 'rb') as f: shutil.copyfileobj(f, self._data_file) + self._data_file.flush() + assert os.stat(self._data_file.name).st_size != 0, f"Zero-sized file: {self._data_file.name}" def finalize(self, index_file): self._data_file.close() From 083197ea1be39cf6781da4f8005eacc2ee40819e Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Tue, 13 Feb 2024 16:17:46 -0800 Subject: [PATCH 22/27] Update return codes in PyTest to properly error out if tests fail (#5122) Update PyTest return codes to better reflect the [documented exit codes](https://docs.pytest.org/en/7.1.x/reference/exit-codes.html). This was causing issues in the nv-inference test. Build with nv-inference with just the returncode changes, showing that as [now failing](https://github.com/microsoft/DeepSpeed/actions/runs/7880331865/job/21502182220) with CUDA OOM errors. Sample previous passing build [here](https://github.com/microsoft/DeepSpeed/actions/runs/7719051411/job/21041568991). --- .github/workflows/nv-inference.yml | 2 +- tests/unit/common.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/nv-inference.yml b/.github/workflows/nv-inference.yml index 579a638b567b..cc55136d1a4b 100644 --- a/.github/workflows/nv-inference.yml +++ b/.github/workflows/nv-inference.yml @@ -57,6 +57,6 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - pytest $PYTEST_OPTS -m 'seq_inference' unit/ --torch_ver="1.13" --cuda_ver="11.6" + #pytest $PYTEST_OPTS -m 'seq_inference' unit/ --torch_ver="1.13" --cuda_ver="11.6" pytest $PYTEST_OPTS -m 'inference_ops' unit/ --torch_ver="1.13" --cuda_ver="11.6" pytest $PYTEST_OPTS --forked -n 4 -m 'inference' unit/ --torch_ver="1.13" --cuda_ver="11.6" diff --git a/tests/unit/common.py b/tests/unit/common.py index 420db577cf09..76bebf6b725a 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -168,7 +168,7 @@ def _launch_daemonic_procs(self, num_procs): # Shortcut to exit pytest in the case of a hanged test. This # usually means an environment error and the rest of tests will # hang (causing super long unit test runtimes) - pytest.exit("Test hanged, exiting", returncode=0) + pytest.exit("Test hanged, exiting", returncode=1) # Tear down distributed environment and close process pools self._close_pool(pool, num_procs) @@ -204,7 +204,7 @@ def _launch_non_daemonic_procs(self, num_procs): if not any_done: for p in processes: p.terminate() - pytest.exit("Test hanged, exiting", returncode=0) + pytest.exit("Test hanged, exiting", returncode=1) # Wait for all other processes to complete for p in processes: From b20c46745b3d018cc9417818c9b26bd981163af1 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Wed, 14 Feb 2024 15:14:24 -0800 Subject: [PATCH 23/27] add missing methods to MPS_Accelerator (#5134) #5026 introduced new abstract methods for the base accelerator class. These methods were not defined for `MPS_Accelerator`. Fixes #5132 --- accelerator/mps_accelerator.py | 9 ++++ tests/unit/accelerator/test_accelerator.py | 59 ++++++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 tests/unit/accelerator/test_accelerator.py diff --git a/accelerator/mps_accelerator.py b/accelerator/mps_accelerator.py index f6303cf9890f..972b33caece1 100644 --- a/accelerator/mps_accelerator.py +++ b/accelerator/mps_accelerator.py @@ -24,6 +24,15 @@ def __init__(self): def is_synchronized_device(self): return False + def use_host_timers(self): + return self.is_synchronized_device() + + def resolves_data_dependency(self): + return self.is_synchronized_device() + + def handles_memory_backpressure(self): + return self.is_synchronized_device() + # Device APIs def device_name(self, device_index=None): if device_index is None: diff --git a/tests/unit/accelerator/test_accelerator.py b/tests/unit/accelerator/test_accelerator.py new file mode 100644 index 000000000000..964cf2b24f4e --- /dev/null +++ b/tests/unit/accelerator/test_accelerator.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest + +import os +import sys +import importlib +import re + +import deepspeed + +DS_ACCEL_PATH = "deepspeed.accelerator" +IGNORE_FILES = ["abstract_accelerator.py", "real_accelerator.py"] + + +@pytest.fixture +def accel_class_name(module_name): + class_list = [] + mocked_modules = [] + + # Get the accelerator class name for a given module + while True: + try: + module = importlib.import_module(module_name) + break + except ModuleNotFoundError as e: + # If the environment is missing a module, mock it so we can still + # test importing the accelerator class + missing_module = re.search(r"\'(.*)\'", e.msg).group().strip("'") + sys.modules[missing_module] = lambda x: None + mocked_modules.append(missing_module) + for name in dir(module): + if name.endswith("_Accelerator"): + class_list.append(name) + + assert len(class_list) == 1, f"Multiple accelerator classes found in {module_name}" + + yield class_list[0] + + # Clean up mocked modules so as to not impact other tests + for module in mocked_modules: + del sys.modules[module] + + +@pytest.mark.parametrize( + "module_name", + [ + DS_ACCEL_PATH + "." + f.rstrip(".py") for f in os.listdir(deepspeed.accelerator.__path__[0]) + if f.endswith("_accelerator.py") and f not in IGNORE_FILES + ], +) +def test_abstract_methods_defined(module_name, accel_class_name): + module = importlib.import_module(module_name) + accel_class = getattr(module, accel_class_name) + accel_class.__init__ = lambda self: None + _ = accel_class() From 3e5d4004732bd2fd91a44c2e8ec5a4e95da6f7e2 Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Thu, 15 Feb 2024 07:52:59 +0100 Subject: [PATCH 24/27] Solve tensor vs numpy dtype conflicts in data efficiency map-reduce. (#5108) The map-reduce types are a mess. By looking at the file `deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py`, we see that the reduce only accepts numpy types due to the following check: ``` dtypes = { 1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float64, 7: np.double, 8: np.uint16, 9: np.uint32, 10: np.uint64 } def code(dtype): for k in dtypes.keys(): if dtypes[k] == dtype: return k raise ValueError(dtype) ``` Now the issue is that python and torch types are not equal (in python) for the same representation: ``` > type(int) == type(np.int64) True > type(torch.int64) == type(np.int64) False ``` And the user-specified `metric_function` needs to return a tensor, so it will have automatically have a torch type. If the user does not specify a tensor, then this fails: In `deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py`: ``` def update_metric_results(self, data, metric_types, metric_dtypes, metric_functions, metric_results): for m_idx in range(len(metric_types)): [...] if metric_type == 'single_value_per_sample': metric_values = metric_function(data) for row in range(metric_values.size()[0]): ``` Only a `torch.Tensor` has the `.size()` attribute: ``` > np.array([1,2,3]).size() TypeError: 'int' object is not callable > torch.tensor([1,2,3]).size() torch.Size([3]) ``` So to my understanding: the user must create a `DataAnalyser` with a `metric_dtypes` which is of a numpy dtype, yet needs to provide a `metric_function` function that returns a torch dtype that **must match the same data type as numpy**, e.g. ``` def metric_functions(int_list): return torch.tensor(int_list).as(torch.int64). #<-- TORCH type required here data_analyzer = DataAnalyzer( dataset=train_dataset, metric_names=["seqlen"], metric_functions=[metric_functions], metric_types=['single_value_per_sample'], metric_dtypes=[np.int64], ### <--- NUMPY type required here ) ``` Finally there's no datatype check, so if a user forgets to add `.as(torch.int64)` to the `metric_functions`, then the files output by threads will be called e.g. `seqlen/worker0_thread0/seqlen_metric_to_sample_730.0.csv` as the integer `730` is defaulted to `float`. This would later fail as the reduce step would look for `seqlen/worker0_thread0/seqlen_metric_to_sample_730.csv` instead. This PR adds support to both `np.ndarray` and `torch.tensor` return dtypes on function `metric_function`. When dealing with tensors, it converts to the corresponding numpy dtype before outputting. It also adds several `asserts` to make sure use provides the correct return type and dtype on `metric_function` and `metric_dtype`, respectively. --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- .../data_sampling/data_analyzer.py | 19 ++++++---- .../data_sampling/indexed_dataset.py | 38 +++++++++---------- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py index c9957bb9ea3c..1522a3d94226 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py @@ -13,7 +13,7 @@ from torch.utils.data import BatchSampler, SequentialSampler, DataLoader, Subset from deepspeed.utils import logger -from .indexed_dataset import MMapIndexedDataset +from .indexed_dataset import MMapIndexedDataset, valid_dtypes from .utils import split_dataset, split_index, create_mmap_dataset_builder, close_mmap_dataset_builder, find_fit_int_dtype @@ -61,9 +61,7 @@ def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtyp for m_idx in range(len(metric_names)): metric_name, metric_type, metric_dtype = metric_names[m_idx], \ metric_types[m_idx], metric_dtypes[m_idx] - assert metric_dtype not in [ - np.float64, np.double - ], "Currently floating point metric values are not supported. Please change your metric into integer values (and potentially multiply a larger coefficient to keep the precision)." + assert metric_dtype in valid_dtypes, f"metric_dtype {metric_dtype} not supported. Supported dtypes {valid_dtypes}" metric_save_path = f"{save_path}/{metric_name}/worker{worker_id}_thread{thread_id}/" os.makedirs(metric_save_path, exist_ok=True) if metric_type == 'single_value_per_sample': @@ -89,8 +87,14 @@ def update_metric_results(self, data, metric_types, metric_dtypes, metric_functi metric_type, metric_dtype, metric_function, metric_result = metric_types[m_idx], \ metric_dtypes[m_idx], metric_functions[m_idx], metric_results[m_idx] metric_values = metric_function(data) - assert metric_values.numpy().dtype == metric_dtype, \ - f"dtype {metric_values.numpy().dtype} returned by metric_function {metric_function} is not consistent with the metric_dtype {metric_dtype}" + + assert torch.is_tensor(metric_values) or isinstance(metric_values, np.ndarray), \ + "metric_function must return a tensor or array" + assert metric_values.dtype == metric_dtype, \ + f"metric_function result dtype {metric_values.dtype} does not match metric_dtype {metric_dtype}" + if isinstance(metric_values, np.ndarray): + metric_values = torch.from_numpy(metric_values) + if metric_type == 'single_value_per_sample': for row in range(metric_values.size()[0]): metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1)) @@ -159,7 +163,8 @@ def run_map_helper(self, thread_id): self.update_metric_results(data, self.metric_types, self.metric_dtypes, self.metric_functions, metric_results) else: - self.custom_map_update(data, self.metric_types, self.metric_functions, metric_results) + self.custom_map_update(data, self.metric_types, self.metric_dtypes, self.metric_functions, + metric_results) processed_sample += self.batch_size duration = (time.time() - start) / 3600.0 remain_duration = duration * total_sample / processed_sample - duration diff --git a/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py b/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py index facf85dbe7b4..7a6963bc27eb 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py @@ -98,25 +98,26 @@ def write_longs(f, a): f.write(np.array(a, dtype=np.int64)) +# valid metric_dtypes as numpy and torch types dtypes = { - 1: np.uint8, - 2: np.int8, - 3: np.int16, - 4: np.int32, - 5: np.int64, - 6: np.float64, - 7: np.double, - 8: np.uint16, - 9: np.uint32, - 10: np.uint64 + 1: (np.uint8, torch.uint8), + 2: (np.int8, torch.int8), + 3: (np.int16, torch.int16), + 4: (np.int32, torch.int32), + 5: (np.int64, torch.int64), + 6: (np.uint16, None), + 7: (np.uint32, None), + 8: (np.uint64, None), } +valid_dtypes = set([dt[0] for dt in dtypes.values()] + [dt[1] for dt in dtypes.values() if dt[1] is not None]) + def code(dtype): - for k in dtypes.keys(): - if dtypes[k] == dtype: - return k - raise ValueError(dtype) + for c, (np_dt, torch_dt) in dtypes.items(): + if dtype in [np_dt, torch_dt]: + return c + raise ValueError(f"{dtype} not supported. Supported types: {valid_dtypes}") def index_file_path(prefix_path): @@ -153,7 +154,7 @@ def read_index(self, path): version = f.read(8) assert struct.unpack(' Date: Thu, 15 Feb 2024 11:34:41 +0100 Subject: [PATCH 25/27] Fix broadcast deadlock for incomplete batches in data sample for data analysis (#5117) When the batch is not a full batch (`drop_last=False`), then the size of the current batch is smaller than the expected: ``` self.global_batch_size = self.micro_batch_times_data_parallel_size * self.gradient_accumulation_steps ``` The `get_next_global_batch()` method will try to broadcast the tensor of a size smaller than `self.global_batch_size` from a master rank (`0`). However, in this case, the master rank will send a shorter tensor. This leads to an unexpected behaviour (deadlock, crash, or `None` tensor on receiving ranks). The documentation for the [broadcast](https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast) operation says "tensor must have the same number of elements in all processes participating in the collective." In the following call, `tensor` can have different sizes when comparing master with other participant ranks. File `deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py`, like `289`: ``` dist.broadcast(batch, 0, group=self.data_parallel_group) ``` This PR fixes that bug, by filling incomplete batch indices with `-1` so that the batch tensor is always of the same size. Note: an alternative resolution is to broadcast beforehand the size of the batches tensor, but adds an extra comm step. The current method of extending the `batch` tensor with `-1`s is memory-safe as the batch tensor will match the one used in previous iterations with a full batch. --- .../runtime/data_pipeline/data_sampling/data_sampler.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py b/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py index ef845e4bc490..088540c0ab4c 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py @@ -281,12 +281,17 @@ def get_next_global_batch(self): for cidx in range(len(samples_per_cluster)): batch += self.get_sample_from_cluster(cidx, samples_per_cluster[cidx]) self.np_rng.shuffle(batch) + + # broadcast tensor must have same shape across participants. So we fill batch with -1s when not full + assert len(batch) <= self.global_batch_size + batch += [-1] * (self.global_batch_size - len(batch)) batch = torch.tensor(batch, device=get_accelerator().current_device_name(), dtype=torch.long).view(-1) else: batch = torch.empty(self.global_batch_size, device=get_accelerator().current_device_name(), dtype=torch.long) dist.broadcast(batch, 0, group=self.data_parallel_group) + batch = batch[batch != -1] # remove trailing -1s used to fill incomplete batch tensor self.batch = batch.tolist() def __iter__(self): From a7864846a4c7c9c7c0a4ab463293d595767bf4b8 Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Thu, 15 Feb 2024 14:14:57 +0100 Subject: [PATCH 26/27] Avoid zero-sized microbatches for incomplete minibatches when doing curriculum learning (#5118) Related to curriculum learning and the data efficiency module. The `get_start_end_idx()` method that is used to compute which batch indices to allocate across data parallel ranks, assumes the batch to be of size `micro-batch size * data_parallel_size` and allocates sequential subsets of indices across data loader processes. When `drop_last=False`, then the global batch size will very likely be smaller than `micro-batch size * data_parallel_size`, and `get_start_end_idx()` will give a full `self.microbatch_size` sized batch to a few initial nodes and the remaining ones will have a zero-sized microbatch. This leads to load imbalance and (probably) wrong updates as gradients are averaged across different microbatch sizes. This PR fixes that by distributing the same amount (+-1 sample) across all data loader ranks, when batch is not complete. --------- Co-authored-by: Conglong Li --- .../data_pipeline/data_sampling/data_sampler.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py b/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py index 088540c0ab4c..100bef3f7946 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py @@ -119,9 +119,15 @@ def set_custom_curriculum_learning_schedule(self, schedule_func_dict): if metric in schedule_func_dict: self.curriculum_schedulers[metric].set_custom_get_difficulty(schedule_func_dict[metric]) - def get_start_end_idx(self): - start_idx = self.data_parallel_rank * self.micro_batch_size - end_idx = start_idx + self.micro_batch_size + def get_start_end_idx(self, batch_len=None): + """ + given the length of a minibatch (defaults to micro-batch size * data_parallel_size), + return the start and end indices of the current data parallel rank + """ + batch_len = batch_len or self.micro_batch_times_data_parallel_size + start_idx_fn = lambda r: round(r * batch_len / self.data_parallel_group.size()) + start_idx = start_idx_fn(self.data_parallel_rank) + end_idx = start_idx_fn(self.data_parallel_rank + 1) return start_idx, end_idx def get_sample_based_on_metric_value(self, metric, value_start, value_end): @@ -302,7 +308,7 @@ def __iter__(self): self.batch = self.batch[self.micro_batch_times_data_parallel_size:] if len(current_batch) == self.micro_batch_times_data_parallel_size or \ (len(current_batch) > 0 and not self.drop_last): - start_idx, end_idx = self.get_start_end_idx() + start_idx, end_idx = self.get_start_end_idx(len(current_batch)) yield current_batch[start_idx:end_idx] self.consumed_samples += len(current_batch) current_batch = [] From 2d0a6bc20ae20c37085ab5bb22db2f0e50725400 Mon Sep 17 00:00:00 2001 From: Bruno Magalhaes Date: Thu, 15 Feb 2024 16:28:04 +0100 Subject: [PATCH 27/27] remove mandatory `index` key from output of `metric_function` in `DataAnalysis` map operation (#5112) When performing the map operation required for the curriculum learning, the output of `metric_function` requires an `index` field: ``` def update_metric_results(self, data, metric_types, metric_dtypes, metric_functions, metric_results): for m_idx in range(len(metric_types)): [...] if metric_type == 'single_value_per_sample': for row in range(metric_values.size()[0]): metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1)) metric_result["metric_to_sample_dict"][metric_values[row].item()].append( data['index'][row][0].item()). ##<------- data['index']?? ``` There is no mention to this in the documentation, where it specifies that the output of `metric_function` should be a dict/DataFrame (?) with an `index` key/column. To makes things worse, on top of that, there is no way for an user to be able to specify a proper `index` value for each sample, because the distribution of samples across workers/threads is not know, as it's done inside `DataAnalysis`: ``` def run_map_helper(self, thread_id): start_idx, end_idx = self.thread_splits[thread_id][0], \ self.thread_splits[thread_id][1] logger.info(f"worker {self.worker_id} thread {thread_id}: start working " \ f"on data subset {start_idx} to {end_idx}") thread_dataset = Subset(self.dataset, list(range(start_idx, end_idx))) sampler = BatchSampler(SequentialSampler(thread_dataset), batch_size=self.batch_size, drop_last=False) ``` Since by design you picked a `SequentialSampler`, then you know beforehand the global id of each each sample of each batch of each thread of each worker by looking at ``` self.worker_splits, self.thread_splits = split_dataset(self.dataset, self.num_workers, self.worker_id, self.num_threads) start_idx, end_idx = thread_splits[t_idx_reduce][0], thread_splits[t_idx_reduce][1] ``` and you can populate that index value correctly, instead of asking the user to provide it. This PR removes the need for `'index'` key in `data` and uses instead the batch, thread, and worker ids to compute the global index of each sample. --- .../data_sampling/data_analyzer.py | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py index 1522a3d94226..f8ce0e3fa2bf 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py @@ -36,7 +36,8 @@ def __init__(self, custom_map_init=None, custom_map_update=None, custom_map_finalize=None, - custom_reduce=None): + custom_reduce=None, + sample_indices=None): super().__init__() self.dataset = dataset self.num_workers = num_workers @@ -55,6 +56,7 @@ def __init__(self, self.custom_map_update = custom_map_update self.custom_map_finalize = custom_map_finalize self.custom_reduce = custom_reduce + self.sample_indices = sample_indices def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtypes, save_path, worker_id): metric_results = [] @@ -82,7 +84,13 @@ def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtyp metric_results.append({"metric_value": metric_value, "metric_value_fname": metric_value_fname}) return metric_results - def update_metric_results(self, data, metric_types, metric_dtypes, metric_functions, metric_results): + def update_metric_results(self, + data, + metric_types, + metric_dtypes, + metric_functions, + metric_results, + batch_start_idx=0): for m_idx in range(len(metric_types)): metric_type, metric_dtype, metric_function, metric_result = metric_types[m_idx], \ metric_dtypes[m_idx], metric_functions[m_idx], metric_results[m_idx] @@ -97,9 +105,13 @@ def update_metric_results(self, data, metric_types, metric_dtypes, metric_functi if metric_type == 'single_value_per_sample': for row in range(metric_values.size()[0]): + sample_idx = batch_start_idx + row # sample idx following dataset iteration order + if 'index' in data: # Megatron use case, sample idx provided in 'index' field + sample_idx = data['index'][row][0].item() + elif self.sample_indices is not None: # user defined shuffling of indices + sample_idx = self.sample_indices[sample_idx] metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1)) - metric_result["metric_to_sample_dict"][metric_values[row].item()].append( - data['index'][row][0].item()) + metric_result["metric_to_sample_dict"][metric_values[row].item()].append(sample_idx) for m_value in metric_result["metric_to_sample_dict"]: if len(metric_result["metric_to_sample_dict"][m_value]) > 100: metric_fname = metric_result["metric_to_sample_fname"] @@ -159,12 +171,13 @@ def run_map_helper(self, thread_id): while True: try: data = next(iterator) + batch_start_idx = start_idx + processed_sample if self.custom_map_update is None: self.update_metric_results(data, self.metric_types, self.metric_dtypes, self.metric_functions, - metric_results) + metric_results, batch_start_idx) else: self.custom_map_update(data, self.metric_types, self.metric_dtypes, self.metric_functions, - metric_results) + metric_results, batch_start_idx) processed_sample += self.batch_size duration = (time.time() - start) / 3600.0 remain_duration = duration * total_sample / processed_sample - duration