Merge branch 'master' into olruwase/pr_6772
loadams authored Dec 26, 2024
2 parents c5455b0 + 85cc5f9 commit f1920b2
Showing 9 changed files with 78 additions and 61 deletions.
45 changes: 24 additions & 21 deletions CODEOWNERS
@@ -8,49 +8,52 @@

# top-level repo folders
/.github/ @loadams
/azure/ @awan-10
/benchmarks/ @awan-10 @tjruwase
/azure/ @loadams
/benchmarks/ @guanhuawang @tjruwase
/bin/ @loadams
/csrc/ @awan-10
/csrc/ @tjruwase
/deepspeed/ @loadams @tjruwase
/docker/ @awan-10
/docker/ @loadams @guanhuawang
/docs/ @loadams @tjruwase
/examples/ @awan-10 @tohtana
/examples/ @jomayeri @tohtana
/op_builder/ @loadams @tjruwase @jomayeri
/release/ @loadams
/release/ @loadams @jomayeri
/requirements/ @loadams
/scripts/ @awan-10
/scripts/ @loadams @tjruwase
/tests/ @tjruwase @loadams @tohtana

# deepspeed
/deepspeed/autotuning/ @loadams
/deepspeed/checkpoint/ @tjruwase
/deepspeed/comm/ @awan-10
/deepspeed/comm/ @guanhuawang
/deepspeed/compression/ @tjruwase
/deepspeed/elasticity/ @awan-10
/deepspeed/elasticity/ @tjruwase
/deepspeed/launcher/ @loadams
/deepspeed/module_inject/ @awan-10
/deepspeed/module_inject/ @hwchen2017 @loadams
/deepspeed/moe/ @tohtana
/deepspeed/monitor/ @awan-10
/deepspeed/monitor/ @tjruwase
/deepspeed/nebula/ @tjruwase
/deepspeed/nvme/ @tjruwase @jomayeri
/deepspeed/ops/ @tohtana
/deepspeed/pipe/ @tohtana @loadams
/deepspeed/profiling/ @loadams
/deepspeed/utils/ @tjruwase @awan-10
/deepspeed/sequence/ @tohtana
/deepspeed/utils/ @tjruwase @tohtana

# inference
/deepspeed/inference/ @awan-10
/deepspeed/model_implementations/ @awan-10
/deepspeed/inference/ @hwchen2017 @tohtana
/deepspeed/model_implementations/ @tohtana @loadams

# training
/deepspeed/runtime/ @tjruwase @tohtana
/deepspeed/runtime/activation_checkpointing/ @tjruwase
/deepspeed/runtime/checkpoint_engine/ @tjruwase
/deepspeed/runtime/comm/ @awan-10
/deepspeed/runtime/compression/ @awan-10
/deepspeed/runtime/comm/ @guanhuawang
/deepspeed/runtime/compression/ @tjruwase
/deepspeed/runtime/data_pipeline/ @tjruwase
/deepspeed/runtime/fp16/ @tjruwase
/deepspeed/runtime/fp16/onebit/ @awan-10
/deepspeed/runtime/pipe/ @loadams
/deepspeed/runtime/swap_tensor/ @tjruwase
/deepspeed/runtime/zero/ @tjruwase
/deepspeed/runtime/domino/ @guanhuawang @hwchen2017
/deepspeed/runtime/fp16/ @tjruwase @tohtana
/deepspeed/runtime/fp16/onebit/ @tjruwase
/deepspeed/runtime/pipe/ @loadams @tohtana
/deepspeed/runtime/swap_tensor/ @tjruwase @jomayeri
/deepspeed/runtime/zero/ @tjruwase @tohtana
2 changes: 1 addition & 1 deletion accelerator/hpu_accelerator.py
@@ -21,8 +21,8 @@ def __init__(self):
self.apply_hpu_workarounds()
try:
import habana_frameworks.torch.hpu as hpu
hpu.setDeterministic(True)
self.hpu = hpu
torch.use_deterministic_algorithms(True)
except ImportError as e:
raise ValueError(
f"HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.")
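
The change above swaps the Habana-specific hpu.setDeterministic(True) call for PyTorch's portable torch.use_deterministic_algorithms(True) while keeping the hpu handle. A minimal standalone sketch of the resulting initialization pattern (the function name here is illustrative; the real code lives inside HPU_Accelerator.__init__):

import torch

def init_hpu_backend():
    # Guard the optional dependency the same way the diff above does.
    try:
        import habana_frameworks.torch.hpu as hpu
    except ImportError:
        raise ValueError(
            "HPU_Accelerator requires habana_frameworks.torch.hpu, "
            "which is not installed on this system.")
    # Determinism is now requested through the framework-level PyTorch API
    # rather than the HPU-specific hpu.setDeterministic(True).
    torch.use_deterministic_algorithms(True)
    return hpu
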
50 changes: 28 additions & 22 deletions deepspeed/comm/torch.py
@@ -20,6 +20,12 @@
DS_COMM_REDUCE_OFF = False


def disable_compiler_collective(func):
if required_torch_version(min_version=2.3):
return func
return compiler.disable(func)
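
disable_compiler_collective replaces the unconditional @compiler.disable used on the collectives below: on PyTorch 2.3 and newer the call is left visible to torch.compile, while older versions still have it excluded from tracing. A self-contained sketch of the same version-gated pattern (helper names here are illustrative, not DeepSpeed's):

import torch
import torch.distributed

def skip_compile_below(major: int, minor: int):
    """Decorator factory: leave func traceable on new enough PyTorch,
    otherwise exclude it from torch.compile/dynamo tracing."""
    current = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2])

    def decorator(func):
        if current >= (major, minor):
            return func                      # let torch.compile trace the collective
        return torch.compiler.disable(func)  # torch.compiler requires torch >= 2.1
    return decorator

@skip_compile_below(2, 3)
def my_all_reduce(tensor, group=None):
    torch.distributed.all_reduce(tensor, group=group)
    return tensor
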


def build_shm_op():
builder = get_accelerator().create_op_builder("ShareMemCommBuilder")
if builder is None or not deepspeed.ops.__compatible_ops__[builder.NAME]:
@@ -114,7 +120,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='
self.shm_comm_op.initialize(self.get_world_size(), self.get_rank())

@classmethod
@compiler.disable
@disable_compiler_collective
def get_all_gather_function(self):
if hasattr(torch.distributed, "all_gather_into_tensor"):
return torch.distributed.all_gather_into_tensor
@@ -123,7 +129,7 @@ def get_all_gather_function(self):
return None

@classmethod
@compiler.disable
@disable_compiler_collective
def get_reduce_scatter_function(self):
if hasattr(torch.distributed, "reduce_scatter_tensor"):
return torch.distributed.reduce_scatter_tensor
@@ -146,7 +152,7 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size):
world_size=world_size)
self.using_mpi = torch.distributed.get_backend() == 'mpi'

@compiler.disable
@disable_compiler_collective
def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
@@ -158,7 +164,7 @@ def inference_all_reduce(self, tensor, op, group=None):
else:
return torch.ops.deepspeed.inference_all_reduce_(tensor)

@compiler.disable
@disable_compiler_collective
def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
""" proxy func to torch.distributed.all_reduce_coalesced,
which is included in PyTorch 1.13 and above
@@ -169,15 +175,15 @@ def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group
op = self._reduce_op(op)
return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)

@compiler.disable
@disable_compiler_collective
def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
if DS_COMM_REDUCE_OFF:
if int(os.getenv('RANK', '0')) == 0:
utils.logger.warning("REDUCE is OFF")
return Noop()
return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)

@compiler.disable
@disable_compiler_collective
def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
if DS_COMM_REDUCE_SCATTER_OFF:
if int(os.getenv('RANK', '0')) == 0:
@@ -190,7 +196,7 @@ def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_
group=group,
async_op=async_op)

@compiler.disable
@disable_compiler_collective
def broadcast(self, tensor, src, group=None, async_op=False):
if DS_COMM_BROADCAST_OFF:
if int(os.getenv('RANK', '0')) == 0:
@@ -199,7 +205,7 @@ def broadcast(self, tensor, src, group=None, async_op=False):
else:
return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

@compiler.disable
@disable_compiler_collective
def all_gather(self, tensor_list, tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
if int(os.getenv('RANK', '0')) == 0:
@@ -208,15 +214,15 @@ def all_gather(self, tensor_list, tensor, group=None, async_op=False):
else:
return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)

@compiler.disable
@disable_compiler_collective
def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
if self.has_all_gather_into_tensor():
return self.all_gather_function(output_tensor=output_tensor,
input_tensor=input_tensor,
group=group,
async_op=async_op)

@compiler.disable
@disable_compiler_collective
def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
if int(os.getenv('RANK', '0')) == 0:
@@ -234,7 +240,7 @@ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=Fals
"please consider upgrading your pytorch installation.")
pass

@compiler.disable
@disable_compiler_collective
def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
""""""
assert len(output_tensors) == len(input_tensors), ""
@@ -258,7 +264,7 @@ def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_
else:
reqs[-1].wait()

@compiler.disable
@disable_compiler_collective
def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
if self.has_reduce_scatter_tensor():
return self.reduce_scatter_function(output_tensor,
@@ -272,7 +278,7 @@ def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, gr
"please consider upgrading your pytorch installation.")
pass

@compiler.disable
@disable_compiler_collective
def all_to_all_single(self,
output,
input,
@@ -287,49 +293,49 @@ def all_to_all_single(self,
group=group,
async_op=async_op)

@compiler.disable
@disable_compiler_collective
def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False):
return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op)

@compiler.disable
@disable_compiler_collective
def send(self, tensor, dst, group=None, tag=0):
return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)

@compiler.disable
@disable_compiler_collective
def recv(self, tensor, src=None, group=None, tag=0):
return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)

@compiler.disable
@disable_compiler_collective
def isend(self, tensor, dst, group=None, tag=0):
return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)

@compiler.disable
@disable_compiler_collective
def irecv(self, tensor, src=None, group=None, tag=0):
return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)

@compiler.disable
@disable_compiler_collective
def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
return torch.distributed.gather(tensor=tensor,
gather_list=gather_list,
dst=dst,
group=group,
async_op=async_op)

@compiler.disable
@disable_compiler_collective
def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
return torch.distributed.scatter(tensor=tensor,
scatter_list=scatter_list,
src=src,
group=group,
async_op=async_op)

@compiler.disable
@disable_compiler_collective
def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
if group is None:
group = torch.distributed.GroupMember.WORLD
return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)

@compiler.disable
@disable_compiler_collective
def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
if group is None:
group = torch.distributed.GroupMember.WORLD
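
Throughout this file the collectives also honor DS_COMM_*_OFF kill switches and return a Noop() handle when disabled, so callers that expect an async work object can still call wait(). A condensed sketch of that pattern (how the flag is toggled is assumed here via an environment variable; DeepSpeed sets the module-level flags through its own configuration):

import os
import torch.distributed as dist

class Noop:
    def wait(self):
        return None

DS_COMM_REDUCE_OFF = os.getenv("DS_COMM_REDUCE_OFF", "0") == "1"  # illustrative toggle

def reduce(tensor, dst, group=None, async_op=False):
    if DS_COMM_REDUCE_OFF:
        if int(os.getenv("RANK", "0")) == 0:
            print("REDUCE is OFF")          # the real code logs via utils.logger
        return Noop()
    return dist.reduce(tensor=tensor, dst=dst, op=dist.ReduceOp.SUM,
                       group=group, async_op=async_op)
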
6 changes: 5 additions & 1 deletion deepspeed/module_inject/auto_tp.py
@@ -346,11 +346,15 @@ def _replace(self, child, name, conv_linear_layer):
weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
dist.get_world_size(), False)
return LinearAllreduce(weight, bias, self.mp_group)
# For Arctic model, bypass to all_reduce replacement for w2 weights
arctic_w2_all_reduce_linear = False
if 'Arctic' in str(self.module) and 'w2' in name:
arctic_w2_all_reduce_linear = True
# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
return LinearLayer(weight=weight, bias=bias)
if name in self.all_reduce_linears:
if name in self.all_reduce_linears or arctic_w2_all_reduce_linear:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0], weight_shape[1] // mp_size]

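
The new arctic_w2_all_reduce_linear flag routes Arctic's MoE down-projection ("w2") weights to the LinearAllreduce path even though they are not listed in all_reduce_linears, because a weight sharded along its input dimension needs its partial outputs summed across tensor-parallel ranks. A condensed sketch of the routing decision (simplified; not the full _replace implementation, and the example names are hypothetical):

def pick_tp_replacement(module_repr: str, layer_name: str, all_reduce_linears):
    # Arctic's w2 multiplies an input-dimension-sharded weight, so the
    # output must be all-reduced across tensor-parallel ranks.
    arctic_w2 = "Arctic" in module_repr and "w2" in layer_name
    if layer_name in all_reduce_linears or arctic_w2:
        return "LinearAllreduce"   # shard in_features, all-reduce the output
    return "LinearLayer"           # shard out_features, no collective needed

print(pick_tp_replacement("ArcticForCausalLM(...)", "block_sparse_moe.experts.0.w2", set()))
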
21 changes: 10 additions & 11 deletions deepspeed/runtime/domino/transformer.py
@@ -6,8 +6,7 @@
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import deepspeed
from deepspeed import comm as dist
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator


@@ -97,7 +96,7 @@ def backward(ctx, grad_output):
return grad_output

# Async All-reduce.
handle = deepspeed.comm.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
handle = dist.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
ctx.handle_dic[ctx.h_id] = handle
return None, grad_output, None, None

@@ -249,6 +248,10 @@ def __init__(self,
output_bias=None):
super(DominoTransformerLayer, self).__init__()

if not dist.is_initialized():
dist.init_distributed()
assert dist.is_initialized(), "deepspeed.comm is not initialized!"

self.llama_model = config.llama_model
self.layer_number = layer_number
self.layer_type = layer_type
@@ -358,18 +361,14 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
layernorm_output0,
attention_mask,
rotary_pos_emb=rotary_pos_emb)
handle0 = deepspeed.comm.all_reduce(attention_output0,
group=self.mpu.get_tensor_model_parallel_group(),
async_op=True)
handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)

attention_output1, attention_bias1 = \
self.self_attention(
layernorm_output1,
attention_mask,
rotary_pos_emb=rotary_pos_emb)
handle1 = deepspeed.comm.all_reduce(attention_output1,
group=self.mpu.get_tensor_model_parallel_group(),
async_op=True)
handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
handle0.wait()

# Residual0 connection.
@@ -413,7 +412,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
output0 = output0 + bias_c
output0 = self.mlp_activation_func(output0)
output0 = torch.matmul(output0, self.weight_r.t())
handle2 = deepspeed.comm.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
handle2 = dist.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)

handle1.wait()

@@ -425,7 +424,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
if bias_c is not None:
output1 = output1 + bias_c
output1 = torch.matmul(output1, self.weight_r.t())
deepspeed.comm.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())
dist.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())

handle2.wait()

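
Beyond the import cleanup (using deepspeed.comm as dist directly), this hunk keeps Domino's overlap scheme: each tensor-parallel all-reduce is launched with async_op=True, independent work for the other micro-batch runs in between, and the handle is waited on just before its result is consumed. A minimal sketch of that pattern using torch.distributed, whose handle API deepspeed.comm mirrors (function and argument names here are illustrative, and the default process group is assumed to be initialized):

import torch
import torch.distributed as dist

def overlapped_all_reduce(partial_out, independent_work, group=None):
    # Kick off the reduction without blocking.
    handle = dist.all_reduce(partial_out, group=group, async_op=True)
    # Do compute that does not depend on partial_out while comms run.
    other = independent_work()
    # Synchronize only when the reduced tensor is actually needed.
    handle.wait()
    return partial_out, other
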
9 changes: 6 additions & 3 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -310,6 +310,7 @@ def __init__(self,
for param in param_group['params']:
if param.requires_grad:
param.grad_accum = None
param.param_idx_in_group = len(trainable_parameters)
trainable_parameters.append(param)
self.bit16_groups.append(trainable_parameters)

@@ -961,7 +962,7 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"

self.grads_in_ipg_bucket.append(grad_reduc)
self.params_in_ipg_bucket.append((i, param, param_id))
self.params_in_ipg_bucket.append((i, param.param_idx_in_group, param_id))

#make sure the average tensor function knows how to average the gradients
if is_moe_param(param):
@@ -1067,7 +1068,8 @@ def average_tensor(self, tensor):

process_group = self.dp_process_group
# count = 0
for i, param, param_id in self.params_in_ipg_bucket:
for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
param = self.bit16_groups[i][param_idx_in_group]

process_group = self.dp_process_group

@@ -1383,7 +1385,8 @@ def reduce_ipg_grads(self):
stream = get_accelerator().current_stream()

with get_accelerator().stream(stream):
for _, param, param_id in self.params_in_ipg_bucket:
for group_idx, param_idx_in_group, param_id in self.params_in_ipg_bucket:
param = self.bit16_groups[group_idx][param_idx_in_group]

assert self.params_already_reduced[param_id] == False, \
f"The parameter {param_id} has already been reduced. \
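
The ZeRO-1/2 change stores (group index, index within group, param id) in the IPG bucket instead of a direct parameter reference, and resolves the parameter from self.bit16_groups when the bucket is reduced; param_idx_in_group is stamped onto each trainable parameter in __init__. A tiny standalone sketch of that indirection (illustrative names, not the optimizer's real structure):

import torch

# Stand-in for self.bit16_groups: one group of three trainable parameters.
bit16_groups = [[torch.nn.Parameter(torch.zeros(4)) for _ in range(3)]]
for group in bit16_groups:
    for idx, param in enumerate(group):
        param.param_idx_in_group = idx   # mirrors the attribute added in __init__

params_in_ipg_bucket = []

def enqueue(group_idx, param, param_id):
    # Record indices rather than the parameter object itself.
    params_in_ipg_bucket.append((group_idx, param.param_idx_in_group, param_id))

def resolve(group_idx, param_idx_in_group):
    # Look the parameter back up when the bucket is processed.
    return bit16_groups[group_idx][param_idx_in_group]

enqueue(0, bit16_groups[0][1], param_id=42)
g, i, _ = params_in_ipg_bucket[0]
assert resolve(g, i) is bit16_groups[0][1]
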