From 5fb71c0a189408151c3a8968a7025fb61d94950b Mon Sep 17 00:00:00 2001 From: inkcherry Date: Sat, 26 Oct 2024 02:26:47 +0800 Subject: [PATCH 01/11] sequence parallel for uneven heads (#6392) In sequence_parallel (Ulysses), the sequence parallel size is constrained by the requirement to be divisible by the number of heads, which prevents some models/workloads from setting a specific sequence parallel size. This PR implements uneven all-to-all heads splitting. - both support batch first (b,s,...) and seq_len first(s,b..) layout. - Added unit tests with numerical checks. Locally also tested with **7 heads with sp=4** and **20 heads with sp=8**, and it passed. --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase Co-authored-by: Ma, Guokai Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> --- deepspeed/module_inject/tp_shard.py | 4 +- deepspeed/sequence/layer.py | 126 ++++++++++++++++++ deepspeed/utils/groups.py | 2 + .../unit/sequence_parallelism/test_ulysses.py | 84 +++++++++++- 4 files changed, 214 insertions(+), 2 deletions(-) diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py index 6758c7a657f6..57be0c793856 100644 --- a/deepspeed/module_inject/tp_shard.py +++ b/deepspeed/module_inject/tp_shard.py @@ -24,7 +24,9 @@ def set_n_embd(num): def get_num_kv_heads(): global num_kv_heads - return num_kv_heads + if 'num_kv_heads' in globals(): + return num_kv_heads + return None def get_num_attention_heads(): diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index e809fe1118b5..8fd3ca918433 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -10,6 +10,8 @@ import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator +from deepspeed.module_inject.tp_shard import get_shard_size_list, set_num_kv_heads, get_num_kv_heads +from deepspeed.utils import groups def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim): @@ -38,8 +40,132 @@ def post_func(input): return post_func +def uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group): + seq_world_size = dist.get_world_size(group) + inp_shape = list(input.shape) + assert batch_dim_idx in [0, 1], "batch_dim_idx must be either 0 or 1" + + if not (scatter_idx < 2): + input_splits = get_shard_size_list(inp_shape[scatter_idx], seq_world_size) + input = input.transpose(0, scatter_idx).contiguous() + local_heads = input_splits[groups._get_sequence_parallel_rank()] + output_splits = [local_heads] * seq_world_size + + output_buffer_shape = [seq_world_size * local_heads] + list(input.shape[1:]) + output = torch.empty(output_buffer_shape, device=input.device, dtype=input.dtype) + dist.all_to_all_single(output,input,output_split_sizes=output_splits,\ + input_split_sizes=input_splits,group=group) + ###[seq_ws*local_heads, ...] to [seq_ws, local_heads, ...] + output = output.view(seq_world_size, local_heads, *output.shape[1:]) + ###[seq_ws,local_heads,b,seq_len,...] to [seq_ws,seq_len,b,local_heads,...] + + ### batch_dim_idx=0 [seq_ws,local_heads,seq_len,b,...] to [b, seq_ws, seq_len, local_heads ...] + ### batch_dim_idx=1 [seq_ws,local_heads,b,seq_len,...] to [seq_ws,seq_len,b,local_heads,...] + if batch_dim_idx == 0: + order = [3, 0, 2, 1] + list(range(4, len(output.shape))) + output = output.permute(order).contiguous() + ###[b, seq_ws*local_seq_len, local_heads,...] 
+ output = output.view(output.shape[0], inp_shape[gather_idx] * seq_world_size, + *output.shape[3:]).contiguous() + elif batch_dim_idx == 1: + output = output.transpose(1, 3).contiguous() + ###[seq_ws*local_seq_len, b, local_heads,...] + output = output.view(inp_shape[gather_idx] * seq_world_size, *output.shape[2:]).contiguous() + else: + # The compatibility handling of 4D and 3D tensors, standardizing to 3D. + input = input.reshape(input.shape[0], input.shape[1], -1) + + if batch_dim_idx == 0: #b,s,h + input = input.permute(1, 2, 0).contiguous() #s,h,b + elif batch_dim_idx == 1: #s,b,h + input = input.transpose(1, 2).contiguous() #s,h,b + seq_len, h, batch_size = input.shape + num_local_heads_list = get_shard_size_list(get_num_kv_heads(), seq_world_size) + local_heads = num_local_heads_list[groups._get_sequence_parallel_rank()] + h_dim = h // local_heads + local_seq_len = seq_len // seq_world_size + + input = input.view(seq_len * h, batch_size) + local_seq_len_with_heads = int(input.shape[0] / seq_world_size) # dim size of local_seq_len*local_heads*hdim + input_splits = [local_seq_len_with_heads] * seq_world_size + coeff = local_seq_len_with_heads // local_heads #per head: dim size of local_seq_len*hdim + + #uneven seq_world_size coeff, total_heads/local_heads. + heads_scale_coeff = get_num_kv_heads() / local_heads + + output_splits = [num_local_heads * coeff for num_local_heads in num_local_heads_list] + output_buff_d1_size = int(heads_scale_coeff * local_seq_len_with_heads) + total_h = int(inp_shape[gather_idx] * heads_scale_coeff) + output = torch.empty(output_buff_d1_size, input.shape[1], device=input.device, dtype=input.dtype) + dist.all_to_all_single(output,input,output_split_sizes=output_splits, \ + input_split_sizes=input_splits,group=group) + ################## + #suppose 7 heads divide into 4 ranks [2,2,2,1] + #chunk_num_heads_small=floor(7/4)=1 + #chunk_num_heads_large=ceil(7/4)=2 + #num_chunk_heads_large=len([2,2,2])=3, all2all_buffer_counts + #num_chunk_heads_small=len([1])=1, all2all_buffer_counts + #total_num_large_heads=sum([2,2,2])=7 + #total_num_small_heads=sum([1])=1 + + chunk_num_heads_small = get_num_kv_heads() // seq_world_size # even heads compatible + chunk_num_heads_large = chunk_num_heads_small + 1 + num_chunk_heads_large = get_num_kv_heads() % seq_world_size + num_chunk_heads_small = seq_world_size - num_chunk_heads_large + total_num_large_heads = num_chunk_heads_large * chunk_num_heads_large + total_num_small_heads = num_chunk_heads_small * chunk_num_heads_small + + heads_large_combine_size = coeff * total_num_large_heads + heads_small_combine_size = coeff * total_num_small_heads + heads_large_chunk, heads_small_chunk = output.split([heads_large_combine_size, heads_small_combine_size], + dim=0) + heads_large_chunk = heads_large_chunk.view(num_chunk_heads_large, local_seq_len, chunk_num_heads_large, h_dim, + batch_size) + heads_small_chunk = heads_small_chunk.view(num_chunk_heads_small, local_seq_len, chunk_num_heads_small, h_dim, + batch_size) + if batch_dim_idx == 0: + #[all2all_buffer_counts, local_seq_len, n_heads,dim,batch]->[batch,local_seq_len,all2all_buffer_counts*n_heads,dim] + order = [4, 1, 0, 2, 3] + heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(batch_size, local_seq_len, + total_num_large_heads, h_dim) + heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(batch_size, local_seq_len, + total_num_small_heads, h_dim) + elif batch_dim_idx == 1: + #[all2all_buffer_counts, local_seq_len, 
n_heads,dim,batch]->[local_seq_len,batch,all2all_buffer_counts*n_heads,dim] + order = [1, 4, 0, 2, 3] + heads_large_chunk = heads_large_chunk.permute(order).contiguous().view(local_seq_len, batch_size, + total_num_large_heads, h_dim) + heads_small_chunk = heads_small_chunk.permute(order).contiguous().view(local_seq_len, batch_size, + total_num_small_heads, h_dim) + + output = torch.cat([heads_large_chunk, heads_small_chunk], dim=2).contiguous() + + inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size + output_shape= inp_shape[: gather_idx] + \ + [total_h,] + \ + inp_shape[gather_idx + 1:] + + output = output.view(output_shape) + + return output + + def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None): seq_world_size = dist.get_world_size(group) + # we only need num_heads once + num_heads = input.shape[2] + + if get_num_kv_heads() is not None or num_heads % seq_world_size != 0: + # Assuming here that the number of heads for q is consistent with kv + # If not, additional logic is required for cases like GQA + if get_num_kv_heads() is None: + assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})" + # set heads at first call by num_total_heads. + # then use ``get_num_kv_heads() is not None`` to re-entry uneven path. + set_num_kv_heads(num_heads) + assert async_op == False, "uneven head sp does not support async op" + return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group) + if batch_dim_idx == 0: # b, s, n, h if scatter_idx < 2: diff --git a/deepspeed/utils/groups.py b/deepspeed/utils/groups.py index 9dd288ef46db..e9550a0ec25a 100755 --- a/deepspeed/utils/groups.py +++ b/deepspeed/utils/groups.py @@ -484,6 +484,8 @@ def _get_sequence_parallel_rank(): global mpu if mpu is not None and hasattr(mpu, 'get_sequence_parallel_rank'): return mpu.get_sequence_parallel_rank() + if mesh_device is not None: + return dist.get_rank(mesh_device.get_group(mesh_dim="sequence_parallel")) return 0 diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index 915c89e0b00a..d9ed54322d5c 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -11,9 +11,12 @@ from unit.common import DistributedTest from deepspeed.sequence.layer import _SeqAllToAll from unit.util import skip_on_arch +from unit.simple_model import * +from deepspeed.utils import groups +from deepspeed.module_inject.tp_shard import get_shard_size_list +#Use mesh device to create data and sequence parallel group -#Use mesh device to create data and sequence parallel group class TestUlyssesUtils(DistributedTest): world_size = 4 @@ -75,3 +78,82 @@ def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_ # Check outputs are the same as input for i in range(1, len(outputs)): assert torch.allclose(input_tensor, outputs[i]), f"Outputs differ for sequence dim {seq_dims[i]}" + + +@pytest.mark.parametrize("d0", [2, 4]) #batch or sequence dimension +@pytest.mark.parametrize("d1", [4, 8]) #batch or sequence dimension +@pytest.mark.parametrize("num_heads", [3, 7]) +@pytest.mark.parametrize("head_dim", [16]) +class TestUlyssesAll2All_odd(DistributedTest): + world_size = 4 + + def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None: + + data_parallel_size = 2 + seq_parallel_size = self.world_size // 
data_parallel_size + skip_on_arch(min_arch=8) + + def seq_batch_heads_hash(d0, d1, h, offset_d0=0, offset_d1=0, offset_h=0): + d0 += offset_d0 + d1 += offset_d1 + h += offset_h + return d0 * 10 + h + d1 * 0.1 + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + ds_engine, _, _, _ = initialize(model=model, + config_params={"train_batch_size": 8}, + mesh_param=(data_parallel_size, seq_parallel_size)) + + scatter_idx = 2 + outputs = [] + inputs = [] + batch_dims = [0, 1] + seq_dims = [1, 0] + + for idx, seq_dim in enumerate(seq_dims): + gather_idx = seq_dim + batch_dim_idx = batch_dims[idx] + + #4D tensor : b,s,h,d or s,b,h,d + #create a hash tensor from pos_id, head_id, and batch_id + d0_indices = torch.arange(d0).reshape(-1, 1, 1, 1) + d1_indices = torch.arange(d1).reshape(1, -1, 1, 1) + h_indices = torch.arange(num_heads).reshape(1, 1, -1, 1) + input_tensor = torch.randn(d0, d1, num_heads, head_dim, device=ds_engine.device) + if batch_dim_idx == 1: #seq_len_dim : 0(d0) + input_tensor[:] = seq_batch_heads_hash(d0_indices, d1_indices, h_indices, + d0 * groups._get_sequence_parallel_rank(), 0) + elif batch_dim_idx == 0: #seq_len_dim : 1(d1) + input_tensor[:] = seq_batch_heads_hash(d0_indices, d1_indices, h_indices, 0, + d1 * groups._get_sequence_parallel_rank()) + inputs.append(input_tensor) + + ### first all2all: sequence parallel to head parallel + s2h_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, input_tensor, scatter_idx, gather_idx, + batch_dim_idx) + + # s2h_tensor check for the first all2all: compare with the expected ground truth + d0_indices = torch.arange(s2h_tensor.shape[0]).reshape(-1, 1, 1, 1) + d1_indices = torch.arange(s2h_tensor.shape[1]).reshape(1, -1, 1, 1) + h_indices = torch.arange(s2h_tensor.shape[2]).reshape(1, 1, -1, 1) + shard_list = get_shard_size_list(num_heads, groups._get_sequence_parallel_world_size()) + head_offset = sum(shard_list[:groups._get_sequence_parallel_rank()]) + s2h_truth = torch.zeros_like(s2h_tensor) + s2h_truth[:] = seq_batch_heads_hash(d0_indices, d1_indices, h_indices, 0, 0, head_offset) + + assert torch.allclose(s2h_truth, + s2h_tensor), f"s2h_tensor differs from the expected for sequence dim: {seq_dim}" + #No op + ### second all2all: head parallel to sequence parallel + h2s_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, s2h_tensor, gather_idx, scatter_idx, + batch_dim_idx) + print( + f'[{dist.get_rank()}] s={seq_dim} input: {input_tensor.shape} s2h: {s2h_tensor.shape} h2s_tensor: {h2s_tensor.shape}' + ) + outputs.append(h2s_tensor) + + # Check outputs for the second all2all + for i in range(0, len(outputs)): + assert torch.allclose(inputs[i], + outputs[i]), f"[{dist.get_rank()}]Outputs differ for sequence dim {seq_dims[i]}" From 24285d6c73d3e505262a42c91a9d1ba1d9ece154 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Fri, 25 Oct 2024 13:47:22 -0700 Subject: [PATCH 02/11] Add fallback for is_compiling (#6663) Importing `torch.compiler.is_compiling` causes an error with an older version of PyTorch. This PR adds a fallback for `is_compiling` to use an equivalent function of older PyTorch versions. This will resolve #6656. 
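
For reference, this is the fallback pattern the patch introduces, plus a sketch of how a caller can use it without touching `torch.compiler` directly. The `quiet_log` helper below is hypothetical (the real consumer is `deepspeed/utils/logging.py` in the diff); the import chain mirrors the new code in `deepspeed/runtime/compiler.py`:

```python
try:
    # Newer PyTorch exposes the public helper.
    from torch.compiler import is_compiling as torch_is_compiling
except ImportError:
    try:
        # Older PyTorch has an equivalent function under _dynamo.
        from torch._dynamo.external_utils import is_compiling as torch_is_compiling
    except ImportError:
        # This torch build has no compiler support, so we are never compiling.
        torch_is_compiling = lambda: False


def quiet_log(logger, msg):
    # Hypothetical caller: skip logging side effects while a graph is being
    # traced, which avoids graph breaks in compile mode.
    if torch_is_compiling():
        return
    logger.warning(msg)
```
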
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/compiler.py | 13 +++++++++++++ deepspeed/utils/logging.py | 7 +++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py index 879c0a1a2cc9..fa9220f4fcd0 100644 --- a/deepspeed/runtime/compiler.py +++ b/deepspeed/runtime/compiler.py @@ -5,6 +5,15 @@ import torch +try: + from torch.compiler import is_compiling as torch_is_compiling +except ImportError: + try: + from torch._dynamo.external_utils import is_compiling as torch_is_compiling + except ImportError: + # Torch does not have compiler support + torch_is_compiling = lambda: False + def is_compile_supported(): return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile") @@ -14,3 +23,7 @@ def disable(func): if is_compile_supported(): return torch.compiler.disable(func) return func + + +def is_compiling(): + return torch_is_compiling() diff --git a/deepspeed/utils/logging.py b/deepspeed/utils/logging.py index 55fb72f6c9fb..d5db29485db3 100644 --- a/deepspeed/utils/logging.py +++ b/deepspeed/utils/logging.py @@ -7,8 +7,7 @@ import logging import sys import os -import torch -from deepspeed.runtime.compiler import is_compile_supported +from deepspeed.runtime.compiler import is_compile_supported, is_compiling log_levels = { "debug": logging.DEBUG, @@ -26,7 +25,7 @@ def create_warning_filter(logger): def warn_once(record): nonlocal warn - if is_compile_supported() and torch.compiler.is_compiling() and not warn: + if is_compile_supported() and is_compiling() and not warn: warn = True logger.warning("To avoid graph breaks caused by logger in compile-mode, it is recommended to" " disable logging by setting env var DISABLE_LOGS_WHILE_COMPILING=1") @@ -39,7 +38,7 @@ def logging_decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - if torch.compiler.is_compiling(): + if is_compiling(): return else: return func(*args, **kwargs) From 54903e09eb131bb7b69bfc154e3970d4958131b9 Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Fri, 25 Oct 2024 15:14:26 -0700 Subject: [PATCH 03/11] Update profiler registration check (#6668) Resolves #5432. 
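
For context, the old guard in `start_time_hook` tested `__start_time_hook_handle` (no trailing underscores) while the handle is stored under `__start_time_hook_handle__`, so the check could never prevent the pre-hook from being registered again. A minimal sketch of the corrected, idempotent registration (the standalone helper name is hypothetical):

```python
import torch.nn as nn


def attach_start_hook_once(module: nn.Module, hook):
    # The guard and the assignment must use the same attribute name;
    # otherwise every profiling start attaches another copy of the hook.
    if not hasattr(module, "__start_time_hook_handle__"):
        module.__start_time_hook_handle__ = module.register_forward_pre_hook(hook)
    return module.__start_time_hook_handle__
```
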
--- deepspeed/profiling/flops_profiler/profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py index 96306184e42c..f87f1beb7e4e 100644 --- a/deepspeed/profiling/flops_profiler/profiler.py +++ b/deepspeed/profiling/flops_profiler/profiler.py @@ -115,7 +115,7 @@ def start_time_hook(module, input): get_accelerator().synchronize() module.__start_time__ = time.time() - if not hasattr(module, "__start_time_hook_handle"): + if not hasattr(module, "__start_time_hook_handle__"): module.__start_time_hook_handle__ = module.register_forward_pre_hook(start_time_hook) def end_time_hook(module, input, output): From 229960a5e9995643ce0ce957a57d847effdc41dc Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Sun, 27 Oct 2024 20:39:51 -0700 Subject: [PATCH 04/11] Add support for H100/sm_90 arch compilation (#6669) Resolves: #6549 --- op_builder/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/op_builder/builder.py b/op_builder/builder.py index 1609bc9005f4..461281d4a569 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -67,7 +67,7 @@ def get_default_compute_capabilities(): # Special treatment of CUDA 11.0 because compute_86 is not supported. compute_caps += ";8.0" else: - compute_caps += ";8.0;8.6" + compute_caps += ";8.0;8.6;9.0" return compute_caps From b3e959460b9ba255660b9bde3a295c2fb70d3e05 Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Mon, 28 Oct 2024 09:57:53 -0700 Subject: [PATCH 05/11] Update Gaudi2 docker image (#6677) --- .github/workflows/hpu-gaudi2.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/hpu-gaudi2.yml b/.github/workflows/hpu-gaudi2.yml index 4e9ceb32b6b1..fdd270b89a21 100644 --- a/.github/workflows/hpu-gaudi2.yml +++ b/.github/workflows/hpu-gaudi2.yml @@ -39,7 +39,7 @@ jobs: # The type of runner that the job will run on runs-on: [self-hosted, intel, gaudi2] container: - image: vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest + image: vault.habana.ai/gaudi-docker/1.17.1/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest ports: - 80 options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice From e6357c28cd5cfaecab2e541c81e6d633b518e56e Mon Sep 17 00:00:00 2001 From: Raza Sikander Date: Tue, 29 Oct 2024 01:28:25 +0530 Subject: [PATCH 06/11] Update gaudi2 docker version to latest release (1.18) (#6648) Updated docker version to 1.18.0-latest Note: for this update the firmware on the Gaudi2 node had to be updated to use firmware version 1.18. 
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- .github/workflows/hpu-gaudi2.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/hpu-gaudi2.yml b/.github/workflows/hpu-gaudi2.yml index fdd270b89a21..9f1a9d973ca2 100644 --- a/.github/workflows/hpu-gaudi2.yml +++ b/.github/workflows/hpu-gaudi2.yml @@ -39,7 +39,7 @@ jobs: # The type of runner that the job will run on runs-on: [self-hosted, intel, gaudi2] container: - image: vault.habana.ai/gaudi-docker/1.17.1/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest + image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest ports: - 80 options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice From 0e11b081be237f9b1867a3af88479a23de11345f Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Mon, 28 Oct 2024 16:06:02 -0700 Subject: [PATCH 07/11] Update base docker image for A6000 GPU tests (#6681) Update to a [container (24.03)](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-03.html) with python 3.10 as transformers dropped support for python 3.8 in their latest release. Note: nv-human-eval.yml was never completed and isn't used, it is just updated for any potential future support. Resolves: #6672 --- .github/workflows/nv-a6000.yml | 6 +++--- .github/workflows/nv-human-eval.yml | 4 ++-- .github/workflows/nv-sd.yml | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml index 484948b28e34..f094c880c8b6 100644 --- a/.github/workflows/nv-a6000.yml +++ b/.github/workflows/nv-a6000.yml @@ -23,7 +23,7 @@ jobs: unit-tests: runs-on: [self-hosted, nvidia, a6000] container: - image: nvcr.io/nvidia/pytorch:23.03-py3 + image: nvcr.io/nvidia/pytorch:24.03-py3 ports: - 80 options: --gpus all --shm-size "8G" @@ -58,8 +58,8 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2' unit/ --torch_ver="2.0" --cuda_ver="12" - python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2_ops' unit/ --torch_ver="2.0" --cuda_ver="12" + python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2' unit/ --torch_ver="2.3" --cuda_ver="12" + python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2_ops' unit/ --torch_ver="2.3" --cuda_ver="12" - name: MII unit tests run: | BRANCH="main" diff --git a/.github/workflows/nv-human-eval.yml b/.github/workflows/nv-human-eval.yml index 3de878547d6e..2ecdf218b96a 100644 --- a/.github/workflows/nv-human-eval.yml +++ b/.github/workflows/nv-human-eval.yml @@ -11,7 +11,7 @@ jobs: unit-tests: runs-on: [self-hosted, nvidia, a6000] container: - image: nvcr.io/nvidia/pytorch:23.03-py3 + image: nvcr.io/nvidia/pytorch:24.03-py3 ports: - 80 options: --gpus all --shm-size "8G" @@ -50,4 +50,4 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - python -m pytest --color=yes --durations=0 --verbose -rF -m 'evaluation' -k "test_human_eval" unit/ --torch_ver="2.0" --cuda_ver="12" + python -m pytest --color=yes --durations=0 --verbose -rF -m 'evaluation' -k "test_human_eval" unit/ --torch_ver="2.3" --cuda_ver="12" diff --git a/.github/workflows/nv-sd.yml b/.github/workflows/nv-sd.yml index 0344c80451a6..5699b6055782 100644 --- a/.github/workflows/nv-sd.yml 
+++ b/.github/workflows/nv-sd.yml @@ -27,7 +27,7 @@ jobs: sd-tests: runs-on: [self-hosted, nvidia, a6000] container: - image: nvcr.io/nvidia/pytorch:23.03-py3 + image: nvcr.io/nvidia/pytorch:24.03-py3 ports: - 80 options: --gpus all --shm-size "8G" @@ -64,7 +64,7 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - python -m pytest --color=yes --durations=0 --verbose -rF -m 'stable_diffusion' -k "TestStableDiffusion" unit/ --torch_ver="2.0" --cuda_ver="12" + python -m pytest --color=yes --durations=0 --verbose -rF -m 'stable_diffusion' -k "TestStableDiffusion" unit/ --torch_ver="2.3" --cuda_ver="12" - name: Open GitHub issue if weekly CI fails if: ${{ failure() && (github.event_name == 'schedule') }} From 07cac9e0217f14fde1a12a3c89ebe367fcee311a Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Mon, 28 Oct 2024 21:12:29 -0700 Subject: [PATCH 08/11] Remove packages that no longer need to be updated in the latest container (#6682) --- .github/workflows/nv-a6000.yml | 2 -- .github/workflows/nv-sd.yml | 2 -- 2 files changed, 4 deletions(-) diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml index f094c880c8b6..639f27498dd9 100644 --- a/.github/workflows/nv-a6000.yml +++ b/.github/workflows/nv-a6000.yml @@ -47,8 +47,6 @@ jobs: - name: Install deepspeed run: | python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja - # Update packages included in the container that do not support pydantic 2+ to versions that do - python -m pip install thinc spacy confection --upgrade python -m pip install .[dev,1bit,autotuning,inf] ds_report - name: Python environment diff --git a/.github/workflows/nv-sd.yml b/.github/workflows/nv-sd.yml index 5699b6055782..af406075b868 100644 --- a/.github/workflows/nv-sd.yml +++ b/.github/workflows/nv-sd.yml @@ -53,8 +53,6 @@ jobs: pip install image-similarity-measures python -m pip install opencv-python==4.6.* --force-reinstall python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja - # Update packages included in the container that do not support pydantic 2+ to versions that do - python -m pip install thinc spacy confection --upgrade python -m pip install .[dev,1bit,autotuning,sd] ds_report - name: Python environment From e4a247ed133c230db58a625d8008cb60c7ae0f41 Mon Sep 17 00:00:00 2001 From: xuanhua Date: Wed, 30 Oct 2024 00:04:35 +0800 Subject: [PATCH 09/11] Fix training of pipeline based peft's lora model (#5477) Hi, guys I find there is an assert failure when I train huggingface's lora based model in pipeline style. Here is the whole steps that I created my model: 1) Load the pre-trained chatglm-6b model from huggingface, as Model_A 2) Use huggingface's peft's `get_peft_model(...)` and my `LoraConfig(...)` from Model_A to create the lora model, as Model_B 3) Create my own pipeline based model Model_C from Model_B And I run Model_C under 2 3090ti GPUs. 
And the assertion failure looks like this:

```text
Traceback (most recent call last):
  File "/home/ubuntu/proj/chatglm-finetuning/train_pipeline.py", line 372, in <module>
    main()
  File "/home/ubuntu/proj/chatglm-finetuning/train_pipeline.py", line 351, in main
    loss = engine.train_batch(data_iter=train_dataloader)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py", line 375, in train_batch
    self._exec_schedule(sched)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py", line 1375, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py", line 276, in _exec_reduce_tied_grads
    dist.all_reduce(grad, group=group)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/comm/comm.py", line 117, in log_wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/comm/comm.py", line 496, in all_reduce
    return cdb.all_reduce(tensor, op, group, async_op)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/comm/torch.py", line 159, in all_reduce
    return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 1520, in all_reduce
    _check_single_tensor(tensor, "tensor")
  File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 463, in _check_single_tensor
    raise RuntimeError(
RuntimeError: Invalid function argument. Expected parameter `tensor` to be of type torch.Tensor.
```

After some debugging, I found that the root cause is my LoRA configuration (below), which only adds the extra LoRA layers to the QKV-related layers but not to the embedding layer, so all of the embedding layer's parameters are frozen.

```python
lora_config = LoraConfig(r=8,  # copied from finetuning_lora.py
                         lora_alpha=32,
                         target_modules=["query_key_value"],
                         lora_dropout=0.1,
                         bias="none",
                         task_type="CAUSAL_LM",
                         inference_mode=False,
                         )
```

In my pipeline-based model I declared the embedding layer as a tied layer. As a result, the embedding layer has no gradients at all, yet as a tied layer it still has to be synced between the two GPUs, so a gradient value of `None` gets passed to the `all_reduce` operation. My fix is simple: skip the reduction when `grad` is None.
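
A minimal sketch of that guarded reduction (simplified from `_exec_reduce_tied_grads`; it assumes the same `get_tied_weights_and_groups()` helper the diff below touches):

```python
import deepspeed.comm as dist


def reduce_tied_grads(module, using_bf16_optimizer):
    # A frozen tied weight (e.g. an embedding LoRA never touches) has no
    # gradient, so skip the collective instead of handing None to all_reduce.
    for weight, group in module.get_tied_weights_and_groups():
        grad = weight._hp_grad if using_bf16_optimizer else weight.grad
        if grad is not None:
            dist.all_reduce(grad, group=group)
```
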
--------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Heyang Qin Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> --- deepspeed/runtime/pipe/engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 7ebf5487cf9e..b75270cbd306 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -287,7 +287,8 @@ def _exec_reduce_tied_grads(self): weight_group_list = self.module.get_tied_weights_and_groups() for weight, group in weight_group_list: grad = weight._hp_grad if self.using_bf16_optimizer else weight.grad - dist.all_reduce(grad, group=group) + if grad is not None: + dist.all_reduce(grad, group=group) def _exec_reduce_grads(self): self._force_grad_boundary = True From 9b547313c6c213bf6bff5227d0c9689ba1bd618a Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Wed, 30 Oct 2024 10:36:53 -0700 Subject: [PATCH 10/11] Update checkout action to latest version (#5021) Latest checkout uses latest (non-deprecated) version of node (16 -> 20). More information [here](https://github.blog/changelog/2023-09-22-github-actions-transitioning-from-node-16-to-node-20/): ``` Node.js 16 actions are deprecated. Please update the following actions to use Node.js 20: actions/checkout@v3. For more information see: https://github.blog/changelog/2023-09-22-github-actions-transitioning-from-node-16-to-node-20/. ``` Checkout action: https://github.com/actions/checkout Node 20 requires a minimum of Ubuntu 20.04, so workflows currently using 18.04 are failing/will fail. --- .github/workflows/cpu-inference.yml | 2 +- .github/workflows/nv-lightning-v100.yml | 2 +- .github/workflows/nv-torch110-p40.yml | 2 +- .github/workflows/nv-torch110-v100.yml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/cpu-inference.yml b/.github/workflows/cpu-inference.yml index fc0dac5de9a2..007313964f4a 100644 --- a/.github/workflows/cpu-inference.yml +++ b/.github/workflows/cpu-inference.yml @@ -27,7 +27,7 @@ jobs: env: {ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true} # Allow using Node16 actions steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - id: setup-venv uses: ./.github/workflows/setup-venv diff --git a/.github/workflows/nv-lightning-v100.yml b/.github/workflows/nv-lightning-v100.yml index 044c282ba119..f92aa7edfdd5 100644 --- a/.github/workflows/nv-lightning-v100.yml +++ b/.github/workflows/nv-lightning-v100.yml @@ -22,7 +22,7 @@ jobs: runs-on: [self-hosted, nvidia, cu121, v100] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - id: setup-venv uses: ./.github/workflows/setup-venv diff --git a/.github/workflows/nv-torch110-p40.yml b/.github/workflows/nv-torch110-p40.yml index ed639aeb3b62..31d7805db7bb 100644 --- a/.github/workflows/nv-torch110-p40.yml +++ b/.github/workflows/nv-torch110-p40.yml @@ -20,7 +20,7 @@ jobs: env: {ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true} # Allow using Node16 actions steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - id: setup-venv uses: ./.github/workflows/setup-venv diff --git a/.github/workflows/nv-torch110-v100.yml b/.github/workflows/nv-torch110-v100.yml index 4b9f278448ab..bb1bc987379c 100644 --- a/.github/workflows/nv-torch110-v100.yml +++ b/.github/workflows/nv-torch110-v100.yml @@ -20,7 +20,7 @@ jobs: env: {ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true} # Allow using Node16 
actions steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - id: setup-venv uses: ./.github/workflows/setup-venv From c7f58c899f6f099a35d968bdad973f24b842c8c6 Mon Sep 17 00:00:00 2001 From: Yejing-Lai Date: Thu, 31 Oct 2024 08:48:52 +0800 Subject: [PATCH 11/11] Add attribute check to support git-base autotp (#6688) Git-base model is an image-text model. After supporting the llama3.2 vision model, we set num_kv_heads dynamically. Git-base only includes vision_config, so we need to add an attribute check for vision_config/text_config when setting num_kv_heads. Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/module_inject/replace_module.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index dece5a2c4ecf..8b1455f20c69 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -277,8 +277,10 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): if hasattr(model_config, "vision_config"): if "MllamaVisionEncoderLayer" in str(module): num_kv_heads = _autotp.get_model_num_kv_heads(model_config.vision_config) - else: + elif hasattr(model_config, "text_config"): num_kv_heads = _autotp.get_model_num_kv_heads(model_config.text_config) + else: + num_kv_heads = _autotp.get_model_num_kv_heads(model_config) else: num_kv_heads = _autotp.get_model_num_kv_heads(model_config)
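
For reference, a sketch of the num_kv_heads config dispatch that results from this last patch (a simplified restatement of the logic in `replace_wo_policy` above; the helper name is hypothetical):

```python
def select_num_kv_heads(model_config, module, autotp):
    # Mllama vision encoder layers read vision_config and text submodels read
    # text_config; a model like git-base has a vision_config but no
    # text_config, so it now falls back to the top-level config instead of
    # raising an AttributeError.
    if hasattr(model_config, "vision_config"):
        if "MllamaVisionEncoderLayer" in str(module):
            return autotp.get_model_num_kv_heads(model_config.vision_config)
        if hasattr(model_config, "text_config"):
            return autotp.get_model_num_kv_heads(model_config.text_config)
        return autotp.get_model_num_kv_heads(model_config)
    return autotp.get_model_num_kv_heads(model_config)
```
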