From 9f96ad4049b1fb63d38a1a090480dbef61dc0490 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 1 May 2024 18:45:49 +0000 Subject: [PATCH] consolidate communication for tensor metadata --- deepspeed/runtime/pipe/engine.py | 106 ++++++++++++------------------- 1 file changed, 40 insertions(+), 66 deletions(-) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 7d9166607444..fe257f225b30 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -42,6 +42,8 @@ PIPE_RECV_INPUT_TIMER = 'pipe_recv_input' PIPE_RECV_GRAD_TIMER = 'pipe_recv_grad' +TENSOR_META_SIZE = 256 + def is_even(number): return number % 2 == 0 @@ -930,17 +932,17 @@ def _send_tensor_meta(self, buffer, recv_stage): * ndims * shape """ - send_bytes = 0 + meta_buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device) if isinstance(buffer, torch.Tensor): - send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[buffer.dtype]]).to(self.device) - type_tensor = torch.LongTensor(data=[0]).to(self.device) - p2p.send(type_tensor, recv_stage) - send_shape = torch.LongTensor(data=buffer.size()).to(self.device) - send_ndims = torch.LongTensor(data=[len(buffer.size())]).to(self.device) - p2p.send(send_dtype, recv_stage) - p2p.send(send_ndims, recv_stage) - p2p.send(send_shape, recv_stage) - send_bytes += _tensor_bytes(buffer) + meta_buf_list = [ + 0, # type of data (0: tensor, 1: list, 2: tuple) + self.DTYPE_TO_ID[buffer.dtype], # dtype + len(buffer.size()) # ndims + ] + meta_buf_list.extend(buffer.size()) + meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32)) + p2p.send(meta_buffer, recv_stage) + elif isinstance(buffer, list): assert (False) type_tensor = torch.LongTensor(data=[1]).to(self.device) @@ -953,30 +955,21 @@ def _send_tensor_meta(self, buffer, recv_stage): send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device) p2p.send(send_ndims, recv_stage) p2p.send(send_shape, recv_stage) - send_bytes += _tensor_bytes(tensor) elif isinstance(buffer, tuple): - type_tensor = torch.LongTensor(data=[2]).to(self.device) - p2p.send(type_tensor, recv_stage) - count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device) - p2p.send(count_tensor, recv_stage) - for idx, tensor in enumerate(buffer): + meta_buf_list = [ + 2, # type of data (0: tensor, 1: list, 2: tuple) + len(buffer) # num_tensors + ] + + for tensor in buffer: assert isinstance(tensor, torch.Tensor) - send_shape = torch.LongTensor(data=tensor.size()).to(self.device) - send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device) - send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[tensor.dtype]]).to(self.device) - p2p.send(send_dtype, recv_stage) - p2p.send(send_ndims, recv_stage) - p2p.send(send_shape, recv_stage) - # Useful for performance debugging. - ''' - new_bytes = _tensor_bytes(tensor) - send_bytes += _tensor_bytes(tensor) - # Useful for performance debugging. - if self.grid.data_parallel_id == 0: - print( - f'STAGE={self.stage_id} pipe-send-volume[{idx}]: shape={send_shape} {new_bytes/1024**2:0.2f}MB' - ) - ''' + meta_buf_list.append(self.DTYPE_TO_ID[tensor.dtype]) + meta_buf_list.append(len(tensor.size())) + meta_buf_list.extend(tensor.size()) + + meta_buffer[:len(meta_buf_list)].copy_(torch.tensor(meta_buf_list, dtype=torch.int32)) + p2p.send(meta_buffer, recv_stage) + else: raise NotImplementedError(f'Could not send meta type {type(buffer)}') @@ -989,53 +982,34 @@ def _send_tensor_meta(self, buffer, recv_stage): def _recv_tensor_meta(self, send_stage): """Receive metadata about upcoming p2p transfers and return allocated buffers. - Metadata is communicated in this order: - * type (0: tensor, 1: list) - * num_tensors if type=list - foreach tensor in buffer: - * ndims - * shape - Returns: Allocated buffer for receiving from send_stage. """ + buffer = torch.empty(TENSOR_META_SIZE, dtype=torch.int32, device=self.device) + p2p.recv(buffer, send_stage) - type_tensor = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(type_tensor, send_stage) - recv_type = type_tensor.item() + recv_type = buffer[0].item() # A single tensor will be sent. if recv_type == 0: - recv_dtype = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(recv_dtype, send_stage) - recv_dtype = self.ID_TO_DTYPE[recv_dtype.item()] - recv_ndims = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(recv_ndims, send_stage) - recv_ndims = recv_ndims.item() - recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device) - p2p.recv(recv_shape, send_stage) - recv_shape = recv_shape.tolist() + recv_dtype = self.ID_TO_DTYPE[buffer[1].item()] + recv_ndims = buffer[2].item() + recv_shape = buffer[3:3 + recv_ndims].tolist() return self._allocate_or_extend_buffers(0, recv_shape, recv_dtype) # List or tuple of tensors elif recv_type == 1 or recv_type == 2: - count_tensor = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(count_tensor, send_stage) - num_tensors = count_tensor.item() - recv_shapes_and_dtypes = [] + num_tensors = buffer[1].item() + buffers = [] + offset = 2 for idx in range(num_tensors): - recv_dtype = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(recv_dtype, send_stage) - recv_dtype = self.ID_TO_DTYPE[recv_dtype.item()] - recv_ndims = torch.LongTensor(data=[0]).to(self.device) - p2p.recv(recv_ndims, send_stage) - recv_ndims = recv_ndims.item() - recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device) - p2p.recv(recv_shape, send_stage) - recv_shapes_and_dtypes.append((recv_shape.tolist(), recv_dtype)) - - buffers.append(self._allocate_or_extend_buffers(idx, recv_shape.tolist(), recv_dtype)) + recv_dtype = self.ID_TO_DTYPE[buffer[offset].item()] + recv_ndims = buffer[offset + 1].item() + recv_shape = buffer[offset + 2:offset + 2 + recv_ndims].tolist() + offset += 2 + recv_ndims + + buffers.append(self._allocate_or_extend_buffers(idx, recv_shape, recv_dtype)) # Convert to tuples if requested. if recv_type == 2: