From 236d13db9c137f697f314d86fd62b51d17f326a3 Mon Sep 17 00:00:00 2001
From: Nadav Elyahu
Date: Tue, 1 Oct 2024 14:15:07 +0300
Subject: [PATCH 1/2] Allow to compile collective for PT > 2.3

---
 deepspeed/comm/torch.py | 50 ++++++++++++++++++++++-------------
 1 file changed, 27 insertions(+), 23 deletions(-)

diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py
index ed2645d415c4..48bd18e24b7e 100755
--- a/deepspeed/comm/torch.py
+++ b/deepspeed/comm/torch.py
@@ -19,6 +19,10 @@
 DS_COMM_ALL_REDUCE_OFF = False
 DS_COMM_REDUCE_OFF = False
 
+def disable_compiler_collective(func):
+    if required_torch_version(min_version=2.3):
+        return func
+    return compiler.disable(func)
 
 def build_shm_op():
     builder = get_accelerator().create_op_builder("ShareMemCommBuilder")
@@ -114,7 +118,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='
             self.shm_comm_op.initialize(self.get_world_size(), self.get_rank())
 
     @classmethod
-    @compiler.disable
+    @disable_compiler_collective
     def get_all_gather_function(self):
         if hasattr(torch.distributed, "all_gather_into_tensor"):
             return torch.distributed.all_gather_into_tensor
@@ -123,7 +127,7 @@ def get_all_gather_function(self):
         return None
 
     @classmethod
-    @compiler.disable
+    @disable_compiler_collective
     def get_reduce_scatter_function(self):
         if hasattr(torch.distributed, "reduce_scatter_tensor"):
             return torch.distributed.reduce_scatter_tensor
@@ -146,19 +150,19 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size):
                                                  world_size=world_size)
         self.using_mpi = torch.distributed.get_backend() == 'mpi'
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
         op = self._reduce_op(op)
         return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
 
     def inference_all_reduce(self, tensor, op, group=None):
-        if not hasattr(torch.ops, 'deepspeed') or not hasattr(torch.ops.deepspeed, 'inference_all_reduce_'):
+        if self.shm_comm_op == None or self.shm_comm_op.inference_all_reduce(tensor, op) == -1:
             op = self._reduce_op(op)
             return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=False)
         else:
             return torch.ops.deepspeed.inference_all_reduce_(tensor)
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
         """ proxy func to torch.distributed.all_reduce_coalesced,
         which is included in PyTorch 1.13 and above
@@ -169,7 +173,7 @@ def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group
         op = self._reduce_op(op)
         return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
         if DS_COMM_REDUCE_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -177,7 +181,7 @@ def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
             return Noop()
         return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
         if DS_COMM_REDUCE_SCATTER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -190,7 +194,7 @@ def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_
                                                     group=group,
                                                     async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def broadcast(self, tensor, src, group=None, async_op=False):
         if DS_COMM_BROADCAST_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -199,7 +203,7 @@ def broadcast(self, tensor, src, group=None, async_op=False):
         else:
             return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_gather(self, tensor_list, tensor, group=None, async_op=False):
         if DS_COMM_ALL_GATHER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -208,7 +212,7 @@ def all_gather(self, tensor_list, tensor, group=None, async_op=False):
         else:
             return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
    def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
         if self.has_all_gather_into_tensor():
             return self.all_gather_function(output_tensor=output_tensor,
@@ -216,7 +220,7 @@ def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_
                                             group=group,
                                             async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
         if DS_COMM_ALL_GATHER_OFF:
             if int(os.getenv('RANK', '0')) == 0:
@@ -234,7 +238,7 @@ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=Fals
                                  "please consider upgrading your pytorch installation.")
             pass
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
         """"""
         assert len(output_tensors) == len(input_tensors), ""
@@ -258,7 +262,7 @@ def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_
             else:
                 reqs[-1].wait()
 
-    @compiler.disable
+    @disable_compiler_collective
     def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
         if self.has_reduce_scatter_tensor():
             return self.reduce_scatter_function(output_tensor,
@@ -272,7 +276,7 @@ def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, gr
                                 "please consider upgrading your pytorch installation.")
             pass
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_to_all_single(self,
                           output,
                           input,
@@ -287,27 +291,27 @@ def all_to_all_single(self,
                                                    group=group,
                                                    async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False):
         return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def send(self, tensor, dst, group=None, tag=0):
         return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)
 
-    @compiler.disable
+    @disable_compiler_collective
     def recv(self, tensor, src=None, group=None, tag=0):
         return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)
 
-    @compiler.disable
+    @disable_compiler_collective
     def isend(self, tensor, dst, group=None, tag=0):
         return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)
 
-    @compiler.disable
+    @disable_compiler_collective
     def irecv(self, tensor, src=None, group=None, tag=0):
         return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)
 
-    @compiler.disable
+    @disable_compiler_collective
     def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
         return torch.distributed.gather(tensor=tensor,
                                         gather_list=gather_list,
@@ -315,7 +319,7 @@ def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
                                         group=group,
                                         async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
         return torch.distributed.scatter(tensor=tensor,
                                          scatter_list=scatter_list,
@@ -323,13 +327,13 @@ def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
                                          group=group,
                                          async_op=async_op)
 
-    @compiler.disable
+    @disable_compiler_collective
     def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
         if group is None:
             group = torch.distributed.GroupMember.WORLD
         return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)
 
-    @compiler.disable
+    @disable_compiler_collective
     def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
         if group is None:
             group = torch.distributed.GroupMember.WORLD

From 0ce079a463238df6395216e94438b826ff161146 Mon Sep 17 00:00:00 2001
From: Nadav Elyahu
Date: Mon, 28 Oct 2024 09:50:06 +0200
Subject: [PATCH 2/2] fix formatting

---
 deepspeed/comm/torch.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py
index 4cdf833d5a12..ebfb396ee5e9 100755
--- a/deepspeed/comm/torch.py
+++ b/deepspeed/comm/torch.py
@@ -19,11 +19,13 @@
 DS_COMM_ALL_REDUCE_OFF = False
 DS_COMM_REDUCE_OFF = False
 
+
 def disable_compiler_collective(func):
     if required_torch_version(min_version=2.3):
         return func
     return compiler.disable(func)
 
+
 def build_shm_op():
     builder = get_accelerator().create_op_builder("ShareMemCommBuilder")
     if builder is None or not deepspeed.ops.__compatible_ops__[builder.NAME]:
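
For readers who want the gist of the first commit without applying the patch, the following is a minimal, self-contained sketch of the conditional-decorator pattern it introduces. The inline version check and the ExampleBackend class are illustrative stand-ins, not the DeepSpeed helpers (required_torch_version, compiler.disable) used in the patch: on PyTorch >= 2.3 the decorator returns the function unchanged so torch.compile can trace the collective call, while on older releases it keeps the previous behavior of excluding the call from Dynamo tracing.

    import torch


    def disable_compiler_collective(func):
        # Sketch only: stand-in for the DeepSpeed helper of the same name.
        # On PyTorch >= 2.3, return the function untouched so torch.compile can
        # trace collective calls; on older releases, exclude it from tracing.
        major, minor = (int(part) for part in torch.__version__.split("+")[0].split(".")[:2])
        if (major, minor) >= (2, 3):
            return func
        return torch._dynamo.disable(func)


    class ExampleBackend:
        @disable_compiler_collective
        def all_reduce(self, tensor, group=None, async_op=False):
            return torch.distributed.all_reduce(tensor, group=group, async_op=async_op)

Because the decorator is applied once at import time, the tracing behavior follows the installed PyTorch version with no call-site changes; that is the property the patch relies on when it swaps @compiler.disable for @disable_compiler_collective on every collective wrapper in deepspeed/comm/torch.py.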