diff --git a/CODEOWNERS b/CODEOWNERS
index c0fc85cb8b89..b0d3b8b0d77b 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -8,49 +8,52 @@
 # top-level repo folders
 /.github/ @loadams
-/azure/ @awan-10
-/benchmarks/ @awan-10 @tjruwase
+/azure/ @loadams
+/benchmarks/ @guanhuawang @tjruwase
 /bin/ @loadams
-/csrc/ @awan-10
+/csrc/ @tjruwase
 /deepspeed/ @loadams @tjruwase
-/docker/ @awan-10
+/docker/ @loadams @guanhuawang
 /docs/ @loadams @tjruwase
-/examples/ @awan-10 @tohtana
+/examples/ @jomayeri @tohtana
 /op_builder/ @loadams @tjruwase @jomayeri
-/release/ @loadams
+/release/ @loadams @jomayeri
 /requirements/ @loadams
-/scripts/ @awan-10
+/scripts/ @loadams @tjruwase
 /tests/ @tjruwase @loadams @tohtana

 # deepspeed
 /deepspeed/autotuning/ @loadams
 /deepspeed/checkpoint/ @tjruwase
-/deepspeed/comm/ @awan-10
+/deepspeed/comm/ @guanhuawang
 /deepspeed/compression/ @tjruwase
-/deepspeed/elasticity/ @awan-10
+/deepspeed/elasticity/ @tjruwase
 /deepspeed/launcher/ @loadams
-/deepspeed/module_inject/ @awan-10
+/deepspeed/module_inject/ @hwchen2017 @loadams
 /deepspeed/moe/ @tohtana
-/deepspeed/monitor/ @awan-10
+/deepspeed/monitor/ @tjruwase
 /deepspeed/nebula/ @tjruwase
+/deepspeed/nvme/ @tjruwase @jomayeri
 /deepspeed/ops/ @tohtana
 /deepspeed/pipe/ @tohtana @loadams
 /deepspeed/profiling/ @loadams
-/deepspeed/utils/ @tjruwase @awan-10
+/deepspeed/sequence/ @tohtana
+/deepspeed/utils/ @tjruwase @tohtana

 # inference
-/deepspeed/inference/ @awan-10
-/deepspeed/model_implementations/ @awan-10
+/deepspeed/inference/ @hwchen2017 @tohtana
+/deepspeed/model_implementations/ @tohtana @loadams

 # training
 /deepspeed/runtime/ @tjruwase @tohtana
 /deepspeed/runtime/activation_checkpointing/ @tjruwase
 /deepspeed/runtime/checkpoint_engine/ @tjruwase
-/deepspeed/runtime/comm/ @awan-10
-/deepspeed/runtime/compression/ @awan-10
+/deepspeed/runtime/comm/ @guanhuawang
+/deepspeed/runtime/compression/ @tjruwase
 /deepspeed/runtime/data_pipeline/ @tjruwase
-/deepspeed/runtime/fp16/ @tjruwase
-/deepspeed/runtime/fp16/onebit/ @awan-10
-/deepspeed/runtime/pipe/ @loadams
-/deepspeed/runtime/swap_tensor/ @tjruwase
-/deepspeed/runtime/zero/ @tjruwase
+/deepspeed/runtime/domino/ @guanhuawang @hwchen2017
+/deepspeed/runtime/fp16/ @tjruwase @tohtana
+/deepspeed/runtime/fp16/onebit/ @tjruwase
+/deepspeed/runtime/pipe/ @loadams @tohtana
+/deepspeed/runtime/swap_tensor/ @tjruwase @jomayeri
+/deepspeed/runtime/zero/ @tjruwase @tohtana
diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py
index 723a66e4c6fb..b46351f8ca43 100644
--- a/accelerator/hpu_accelerator.py
+++ b/accelerator/hpu_accelerator.py
@@ -21,8 +21,8 @@ def __init__(self):
         self.apply_hpu_workarounds()
         try:
             import habana_frameworks.torch.hpu as hpu
-            hpu.setDeterministic(True)
             self.hpu = hpu
+            torch.use_deterministic_algorithms(True)
         except ImportError as e:
             raise ValueError(
                 f"HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.")
diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py
index 988b74232bb9..5461ae18d1f0 100755
--- a/deepspeed/comm/torch.py
+++ b/deepspeed/comm/torch.py
@@ -20,6 +20,12 @@
 DS_COMM_REDUCE_OFF = False


+def disable_compiler_collective(func):
+    if required_torch_version(min_version=2.3):
+        return func
+    return compiler.disable(func)
+
+
 def build_shm_op():
     builder = get_accelerator().create_op_builder("ShareMemCommBuilder")
     if builder is None or not deepspeed.ops.__compatible_ops__[builder.NAME]:
@@ -114,7 +120,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='
             self.shm_comm_op.initialize(self.get_world_size(), self.get_rank())

     @classmethod
-    @compiler.disable
+    @disable_compiler_collective
     def get_all_gather_function(self):
         if hasattr(torch.distributed, "all_gather_into_tensor"):
             return torch.distributed.all_gather_into_tensor
@@ -123,7 +129,7 @@ def get_all_gather_function(self):
         return None

     @classmethod
-    @compiler.disable
+    @disable_compiler_collective
     def get_reduce_scatter_function(self):
         if hasattr(torch.distributed, "reduce_scatter_tensor"):
             return torch.distributed.reduce_scatter_tensor
@@ -146,7 +152,7 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size):
                                                  world_size=world_size)
         self.using_mpi = torch.distributed.get_backend() == 'mpi'

-    @compiler.disable
+    @disable_compiler_collective
     def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
         op = self._reduce_op(op)
         return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
@@ -158,7 +164,7 @@ def inference_all_reduce(self, tensor, op, group=None):
         else:
             return torch.ops.deepspeed.inference_all_reduce_(tensor)

-    @compiler.disable
+    @disable_compiler_collective
     def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
         """ proxy func to torch.distributed.all_reduce_coalesced,
         which is included in PyTorch 1.13 and above
@@ -169,7 +175,7 @@ def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group
         op = self._reduce_op(op)
         return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
         if DS_COMM_REDUCE_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -177,7 +183,7 @@ def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
             return Noop()
         return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
         if DS_COMM_REDUCE_SCATTER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -190,7 +196,7 @@ def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_
                                                      group=group,
                                                      async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def broadcast(self, tensor, src, group=None, async_op=False):
         if DS_COMM_BROADCAST_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -199,7 +205,7 @@ def broadcast(self, tensor, src, group=None, async_op=False):
         else:
             return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def all_gather(self, tensor_list, tensor, group=None, async_op=False):
         if DS_COMM_ALL_GATHER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -208,7 +214,7 @@ def all_gather(self, tensor_list, tensor, group=None, async_op=False):
         else:
             return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
         if self.has_all_gather_into_tensor():
             return self.all_gather_function(output_tensor=output_tensor,
@@ -216,7 +222,7 @@ def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_
                                             group=group,
                                             async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
         if DS_COMM_ALL_GATHER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -234,7 +240,7 @@ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=Fals
                 "please consider upgrading your pytorch installation.")
             pass

-    @compiler.disable
+    @disable_compiler_collective
     def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
         """"""
         assert len(output_tensors) == len(input_tensors), ""
@@ -258,7 +264,7 @@ def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_
         else:
             reqs[-1].wait()

-    @compiler.disable
+    @disable_compiler_collective
     def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
         if self.has_reduce_scatter_tensor():
             return self.reduce_scatter_function(output_tensor,
@@ -272,7 +278,7 @@ def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, gr
                 "please consider upgrading your pytorch installation.")
             pass

-    @compiler.disable
+    @disable_compiler_collective
     def all_to_all_single(self,
                           output,
                           input,
@@ -287,27 +293,27 @@ def all_to_all_single(self,
                                                   group=group,
                                                   async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False):
         return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def send(self, tensor, dst, group=None, tag=0):
         return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)

-    @compiler.disable
+    @disable_compiler_collective
     def recv(self, tensor, src=None, group=None, tag=0):
         return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)

-    @compiler.disable
+    @disable_compiler_collective
     def isend(self, tensor, dst, group=None, tag=0):
         return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)

-    @compiler.disable
+    @disable_compiler_collective
     def irecv(self, tensor, src=None, group=None, tag=0):
         return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)

-    @compiler.disable
+    @disable_compiler_collective
     def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
         return torch.distributed.gather(tensor=tensor,
                                         gather_list=gather_list,
@@ -315,7 +321,7 @@ def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
                                         group=group,
                                         async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
         return torch.distributed.scatter(tensor=tensor,
                                          scatter_list=scatter_list,
@@ -323,13 +329,13 @@ def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
                                          group=group,
                                          async_op=async_op)

-    @compiler.disable
+    @disable_compiler_collective
     def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
         if group is None:
             group = torch.distributed.GroupMember.WORLD
         return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)

-    @compiler.disable
+    @disable_compiler_collective
     def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
         if group is None:
             group = torch.distributed.GroupMember.WORLD
diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 221d490a37d2..5441000e581d 100755
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -346,11 +346,15 @@ def _replace(self, child, name, conv_linear_layer):
                 weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
                                                          dist.get_world_size(), False)
                 return LinearAllreduce(weight, bias, self.mp_group)
+        # For Arctic model, bypass to all_reduce replacement for w2 weights
+        arctic_w2_all_reduce_linear = False
+        if 'Arctic' in str(self.module) and 'w2' in name:
+            arctic_w2_all_reduce_linear = True
         # For MLP including chunk layer.
         if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
             weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
             return LinearLayer(weight=weight, bias=bias)
-        if name in self.all_reduce_linears:
+        if name in self.all_reduce_linears or arctic_w2_all_reduce_linear:
             # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
             # else [weight_shape[0], weight_shape[1] // mp_size]
diff --git a/deepspeed/runtime/domino/transformer.py b/deepspeed/runtime/domino/transformer.py
index 8eb95e49c29d..88c5494c8147 100644
--- a/deepspeed/runtime/domino/transformer.py
+++ b/deepspeed/runtime/domino/transformer.py
@@ -6,8 +6,7 @@
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
-import deepspeed
-from deepspeed import comm as dist
+import deepspeed.comm as dist
 from deepspeed.accelerator import get_accelerator

@@ -97,7 +96,7 @@ def backward(ctx, grad_output):
             return grad_output

         # Async All-reduce.
-        handle = deepspeed.comm.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
+        handle = dist.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
         ctx.handle_dic[ctx.h_id] = handle
         return None, grad_output, None, None

@@ -249,6 +248,10 @@ def __init__(self,
                  output_bias=None):
         super(DominoTransformerLayer, self).__init__()

+        if not dist.is_initialized():
+            dist.init_distributed()
+        assert dist.is_initialized(), "deepspeed.comm is not initialized!"
+
         self.llama_model = config.llama_model
         self.layer_number = layer_number
         self.layer_type = layer_type
@@ -358,18 +361,14 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
                 layernorm_output0,
                 attention_mask,
                 rotary_pos_emb=rotary_pos_emb)
-        handle0 = deepspeed.comm.all_reduce(attention_output0,
-                                            group=self.mpu.get_tensor_model_parallel_group(),
-                                            async_op=True)
+        handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)

         attention_output1, attention_bias1 = \
             self.self_attention(
                 layernorm_output1,
                 attention_mask,
                 rotary_pos_emb=rotary_pos_emb)
-        handle1 = deepspeed.comm.all_reduce(attention_output1,
-                                            group=self.mpu.get_tensor_model_parallel_group(),
-                                            async_op=True)
+        handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
         handle0.wait()

         # Residual0 connection.
@@ -413,7 +412,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
             output0 = output0 + bias_c
         output0 = self.mlp_activation_func(output0)
         output0 = torch.matmul(output0, self.weight_r.t())
-        handle2 = deepspeed.comm.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
+        handle2 = dist.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)

         handle1.wait()
@@ -425,7 +424,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
         if bias_c is not None:
             output1 = output1 + bias_c
         output1 = torch.matmul(output1, self.weight_r.t())
-        deepspeed.comm.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())
+        dist.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())

         handle2.wait()
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 7ac89a233808..ecb2a527f870 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -310,6 +310,7 @@ def __init__(self,
             for param in param_group['params']:
                 if param.requires_grad:
                     param.grad_accum = None
+                    param.param_idx_in_group = len(trainable_parameters)
                     trainable_parameters.append(param)
             self.bit16_groups.append(trainable_parameters)
@@ -961,7 +962,7 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
         assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"

         self.grads_in_ipg_bucket.append(grad_reduc)
-        self.params_in_ipg_bucket.append((i, param, param_id))
+        self.params_in_ipg_bucket.append((i, param.param_idx_in_group, param_id))

         #make sure the average tensor function knows how to average the gradients
         if is_moe_param(param):
@@ -1067,7 +1068,8 @@ def average_tensor(self, tensor):
             process_group = self.dp_process_group
             # count = 0
-            for i, param, param_id in self.params_in_ipg_bucket:
+            for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+                param = self.bit16_groups[i][param_idx_in_group]

                 process_group = self.dp_process_group
@@ -1383,7 +1385,8 @@ def reduce_ipg_grads(self):
             stream = get_accelerator().current_stream()

         with get_accelerator().stream(stream):
-            for _, param, param_id in self.params_in_ipg_bucket:
+            for group_idx, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+                param = self.bit16_groups[group_idx][param_idx_in_group]

                 assert self.params_already_reduced[param_id] == False, \
                     f"The parameter {param_id} has already been reduced. \
diff --git a/docs/_tutorials/automatic-tensor-parallelism.md b/docs/_tutorials/automatic-tensor-parallelism.md
index d5a08b27bf4d..6488f9b718fe 100755
--- a/docs/_tutorials/automatic-tensor-parallelism.md
+++ b/docs/_tutorials/automatic-tensor-parallelism.md
@@ -121,6 +121,7 @@ The following results were collected using V100 SXM2 32GB GPUs.
 The following model families have been successfully tested with automatic tensor parallelism. Other models may work but have not been tested yet.

 - albert
+- arctic
 - baichuan
 - bert
 - bigbird_pegasus
diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py
index 9ee546437f6c..c67a907c6785 100644
--- a/tests/unit/moe/test_moe.py
+++ b/tests/unit/moe/test_moe.py
@@ -93,7 +93,8 @@ def strict_average_tensor(tensor):
             process_group = optimizer.dp_process_group
             curr_size = 0
             pg_offsets = []
-            for i, param, param_id in optimizer.params_in_ipg_bucket:
+            for i, param_idx, param_id in optimizer.params_in_ipg_bucket:
+                param = optimizer.bit16_groups[i][param_idx]
                 process_group = optimizer.dp_process_group
                 if optimizer.ipg_bucket_has_moe_params:
                     process_group = optimizer.expert_dp_process_group[param.group_name] if is_moe_param(
diff --git a/version.txt b/version.txt
index 201a22c8fa5c..7eb3095a3295 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.16.2
+0.16.3
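
Note on the disable_compiler_collective helper introduced in deepspeed/comm/torch.py above: on PyTorch 2.3 and newer it returns the collective method unchanged, so torch.compile can trace it, while older versions keep the previous compiler.disable behavior. Below is a minimal standalone sketch of that gating pattern, not code from the patch; the _torch_at_least helper and the use of torch.compiler.disable are illustrative stand-ins for DeepSpeed's required_torch_version and compiler utilities.

# Standalone sketch of the version-gated decorator pattern (illustrative only).
# Assumes PyTorch >= 2.1 is installed so that torch.compiler.disable exists;
# _torch_at_least stands in for DeepSpeed's required_torch_version helper.
import torch


def _torch_at_least(major, minor):
    # Compare only the numeric major.minor part of torch.__version__.
    parts = torch.__version__.split("+")[0].split(".")
    return (int(parts[0]), int(parts[1])) >= (major, minor)


def disable_compiler_collective(func):
    # On torch >= 2.3 the collective stays traceable by torch.compile;
    # on older versions it is wrapped so dynamo skips compiling it.
    if _torch_at_least(2, 3):
        return func
    return torch.compiler.disable(func)


@disable_compiler_collective
def toy_all_reduce(tensor):
    # Stand-in for a real collective call; returns the tensor unchanged.
    return tensor


if __name__ == "__main__":
    print(toy_all_reduce(torch.ones(2)))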