diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py
index 988b74232bb9..5461ae18d1f0 100755
--- a/deepspeed/comm/torch.py
+++ b/deepspeed/comm/torch.py
@@ -20,6 +20,12 @@
 DS_COMM_REDUCE_OFF = False
 
 
+def disable_compiler_collective(func):
+    if required_torch_version(min_version=2.3):
+        return func
+    return compiler.disable(func)
+
+
 def build_shm_op():
     builder = get_accelerator().create_op_builder("ShareMemCommBuilder")
     if builder is None or not deepspeed.ops.__compatible_ops__[builder.NAME]:
@@ -114,7 +120,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='
             self.shm_comm_op.initialize(self.get_world_size(), self.get_rank())
 
     @classmethod
-    @compiler.disable
+    @disable_compiler_collective
     def get_all_gather_function(self):
         if hasattr(torch.distributed, "all_gather_into_tensor"):
             return torch.distributed.all_gather_into_tensor
@@ -123,7 +129,7 @@ def get_all_gather_function(self):
         return None
 
     @classmethod
-    @compiler.disable
+    @disable_compiler_collective
     def get_reduce_scatter_function(self):
         if hasattr(torch.distributed, "reduce_scatter_tensor"):
             return torch.distributed.reduce_scatter_tensor
@@ -146,7 +152,7 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size):
                                                  world_size=world_size)
         self.using_mpi = torch.distributed.get_backend() == 'mpi'
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
         op = self._reduce_op(op)
         return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
@@ -158,7 +164,7 @@ def inference_all_reduce(self, tensor, op, group=None):
         else:
             return torch.ops.deepspeed.inference_all_reduce_(tensor)
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
         """ proxy func to torch.distributed.all_reduce_coalesced,
         which is included in PyTorch 1.13 and above
@@ -169,7 +175,7 @@ def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group
         op = self._reduce_op(op)
         return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
         if DS_COMM_REDUCE_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -177,7 +183,7 @@ def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
             return Noop()
         return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
         if DS_COMM_REDUCE_SCATTER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -190,7 +196,7 @@ def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_
                                                 group=group,
                                                 async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def broadcast(self, tensor, src, group=None, async_op=False):
         if DS_COMM_BROADCAST_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -199,7 +205,7 @@ def broadcast(self, tensor, src, group=None, async_op=False):
             return Noop()
         else:
             return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_gather(self, tensor_list, tensor, group=None, async_op=False):
         if DS_COMM_ALL_GATHER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -208,7 +214,7 @@ def all_gather(self, tensor_list, tensor, group=None, async_op=False):
         else:
             return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
         if self.has_all_gather_into_tensor():
             return self.all_gather_function(output_tensor=output_tensor,
@@ -216,7 +222,7 @@ def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_
                                             group=group,
                                             async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
         if DS_COMM_ALL_GATHER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -234,7 +240,7 @@ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=Fals
                                  "please consider upgrading your pytorch installation.")
             pass
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
         """"""
         assert len(output_tensors) == len(input_tensors), ""
@@ -258,7 +264,7 @@ def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_
             else:
                 reqs[-1].wait()
 
-    @compiler.disable
+    @disable_compiler_collective
     def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
         if self.has_reduce_scatter_tensor():
             return self.reduce_scatter_function(output_tensor,
@@ -272,7 +278,7 @@ def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, gr
                                  "please consider upgrading your pytorch installation.")
             pass
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_to_all_single(self,
                           output,
                           input,
@@ -287,27 +293,27 @@ def all_to_all_single(self,
                                                    group=group,
                                                    async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False):
         return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def send(self, tensor, dst, group=None, tag=0):
         return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)
 
-    @compiler.disable
+    @disable_compiler_collective
     def recv(self, tensor, src=None, group=None, tag=0):
         return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)
 
-    @compiler.disable
+    @disable_compiler_collective
     def isend(self, tensor, dst, group=None, tag=0):
         return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)
 
-    @compiler.disable
+    @disable_compiler_collective
     def irecv(self, tensor, src=None, group=None, tag=0):
         return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)
 
-    @compiler.disable
+    @disable_compiler_collective
     def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
         return torch.distributed.gather(tensor=tensor,
                                         gather_list=gather_list,
@@ -315,7 +321,7 @@ def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
                                         group=group,
                                         async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
    def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
         return torch.distributed.scatter(tensor=tensor,
                                          scatter_list=scatter_list,
@@ -323,13 +329,13 @@ def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
                                          group=group,
                                          async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
         if group is None:
             group = torch.distributed.GroupMember.WORLD
         return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)
 
-    @compiler.disable
+    @disable_compiler_collective
     def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
         if group is None:
             group = torch.distributed.GroupMember.WORLD
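Note: the new disable_compiler_collective helper makes the graph break around collectives conditional. On PyTorch >= 2.3, where torch.compile can trace these collective calls, the decorator returns the function unchanged so the collective stays inside the compiled graph; on older versions it keeps the previous behavior of wrapping the call in compiler.disable. A minimal self-contained sketch of the same pattern, using torch.compiler.disable and an inline version check in place of DeepSpeed's compiler shim and required_torch_version helper (the version parsing below is an illustrative assumption, not the diff's code):

import torch

# Major/minor of the installed torch, e.g. "2.3.0+cu121" -> (2, 3).
TORCH_MAJOR_MINOR = tuple(int(v) for v in torch.__version__.split(".")[:2])

def disable_compiler_collective(func):
    # torch >= 2.3: leave the collective traceable by torch.compile.
    if TORCH_MAJOR_MINOR >= (2, 3):
        return func
    # Older torch: force a graph break so the collective runs eagerly.
    # (torch.compiler.disable is available from torch 2.1 onward.)
    return torch.compiler.disable(func)

class Backend:
    # Hypothetical usage mirroring the decorated methods in the diff.
    @disable_compiler_collective
    def all_reduce(self, tensor, group=None, async_op=False):
        return torch.distributed.all_reduce(tensor, group=group, async_op=async_op)

Because the decision is made once at decoration (import) time rather than per call, the decorator adds no per-invocation overhead on new torch versions.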