diff --git a/csrc/xpu/packbits/packing.cpp b/csrc/xpu/packbits/packing.cpp
index 86d4e799b772..385b83926ce7 100644
--- a/csrc/xpu/packbits/packing.cpp
+++ b/csrc/xpu/packbits/packing.cpp
@@ -5,26 +5,29 @@
 #include
 #include
+#include
 #include
 #include
 
 using namespace sycl;
 using namespace xpu;
 
-void packbitskernel(const bool* input, uint8_t* output, const int input_size, id<1> item_ct1)
+void packbitskernel(const float* input, uint8_t* output, const int input_size, id<1> item_ct1)
 {
+    // get the sign bit of each float and pack eight of them into one byte
     int i = item_ct1;
     for (int j = 0; j < 8; ++j) {
         int k = i * 8 + j;
-        int bit = k < input_size && input[k] != 0;
+        int bit = k < input_size && (!sycl::signbit(input[k]));
         output[i] |= bit << (7 - j);
     }
 }
 
 void unpackbitskernel(const uint8_t* input, float* output, id<1> item_ct1)
 {
+    // use the bit value to set the float: bit 0 -> -1.0f, bit 1 -> 1.0f
     int i = item_ct1;
-    output[i] = (input[i / 8] >> (7 - i % 8)) & 1;
+    output[i] = (float((input[i / 8] >> (7 - i % 8)) & 1) - 0.5) * 2;
 }
 
 sycl::queue get_current_queue(at::Device device)
@@ -38,7 +41,8 @@ sycl::queue get_current_queue(at::Device device)
 at::Tensor packbits(at::Tensor tensor, int input_size, int rank)
 {
     /*
-    pack bool tensor into uint8 tensor. Every eight bool elements get packed into one uint8
+    pack float tensor into uint8 tensor. Every eight float elements get packed into one uint8
+    if float x >= 0, it will be packed as a '1' bit, otherwise as a '0' bit
     Arguments:
-        tensor: A bool tensor that get packed.
+        tensor: A float tensor that gets packed.
         input_size: numel of input tensor
@@ -49,9 +53,9 @@ at::Tensor packbits(at::Tensor tensor, int input_size, int rank)
     int packed_size = (input_size + 7) / 8;
     auto unit8_options = at::TensorOptions().dtype(at::kByte).device(at::kXPU);
 
-    at::Tensor packed = torch::empty({packed_size}, unit8_options);
+    at::Tensor packed = torch::zeros({packed_size}, unit8_options);
 
-    bool* input = (bool*)tensor.data_ptr();
+    float* input = (float*)tensor.data_ptr();
     uint8_t* output = (uint8_t*)packed.data_ptr();
 
     auto event = q.submit([&](sycl::handler& cgh) {
@@ -67,6 +71,7 @@ at::Tensor unpackbits(at::Tensor tensor, int input_size, int rank)
 {
     /*
     unpack uint8 tensor into float tensor. Every uint8 element get unpacked into eight float
+    a '1' bit will be converted to float(1), a '0' bit will be converted to float(-1).
     Arguments:
         tensor: A uint8 tensor that get unpacked.
         input_size: numel of input tensor
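Two details worth noting in the kernel change: `packbitskernel` ORs bits into `output[i]`, which is why the allocation switches from `torch::empty` to `torch::zeros` (each byte must start cleared), and `!sycl::signbit(x)` maps non-negative floats (including +0.0) to a '1' bit. The same semantics can be reproduced on the host with NumPy, whose default MSB-first bit order matches the kernel's `bit << (7 - j)` (a minimal sketch, not part of this PR):

```python
import numpy as np

x = np.array([0.5, -1.25, 3.0, -0.0, 2.0, -7.5, 0.0, 1.0], dtype=np.float32)

# pack: one sign bit per float, non-negative -> 1, negative -> 0
packed = np.packbits(~np.signbit(x))

# unpack: bit 1 -> +1.0, bit 0 -> -1.0, matching (bit - 0.5) * 2 in the kernel
unpacked = (np.unpackbits(packed).astype(np.float32) - 0.5) * 2

assert np.array_equal(unpacked, np.where(np.signbit(x), -1.0, 1.0))
```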
diff --git a/deepspeed/runtime/comm/compressed.py b/deepspeed/runtime/comm/compressed.py
index 3950839895be..7f8c7395451d 100644
--- a/deepspeed/runtime/comm/compressed.py
+++ b/deepspeed/runtime/comm/compressed.py
@@ -5,7 +5,6 @@
 
 import numpy as np
 import torch
-# import torch_npu
 import deepspeed.comm as dist
 from deepspeed.accelerator import get_accelerator
 from deepspeed.ops.op_builder import PackbitsBuilder
@@ -46,11 +45,12 @@ def my_gather(self, rank, size, group, sendbuf, recvbuf, root):
             dist.send(sendbuf, group=group, dst=root)
 
     def pack(self, buffer, size):
-        buffer = buffer.ravel().sign_().add_(1).bool()  # convert buffer to bool, element set to True if its value >=0
-        packed = self.packer.packbits(buffer, buffer.numel(), self.rank)
+        # pack float tensor into uint8 tensor
+        packed = self.packer.packbits(buffer.float(), buffer.numel(), self.rank)
         return packed.reshape(size, -1)
 
     def unpack(self, buffer, size, dtype):
+        # unpack uint8 to float tensor
         unpacked = self.packer.unpackbits(buffer, buffer.numel(), self.rank)
         return unpacked.reshape(size, -1).to(dtype)
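For intuition, here is a rough pure-PyTorch equivalent of what `pack` and `unpack` now delegate to the SYCL kernels. `pack_ref` and `unpack_ref` are hypothetical reference names, not DeepSpeed API, and the sketch assumes the buffer length is already padded to a multiple of 8 (as `compressed_allreduce` arranges):

```python
import torch

def pack_ref(buffer: torch.Tensor, size: int) -> torch.Tensor:
    # one bit per element, non-negative -> 1, negative -> 0 (mirrors !sycl::signbit)
    bits = (~torch.signbit(buffer.float().ravel())).to(torch.uint8).reshape(-1, 8)
    shifts = torch.arange(7, -1, -1, dtype=torch.uint8, device=buffer.device)  # MSB first
    return (bits << shifts).sum(dim=1, dtype=torch.uint8).reshape(size, -1)

def unpack_ref(packed: torch.Tensor, size: int, dtype) -> torch.Tensor:
    shifts = torch.arange(7, -1, -1, dtype=torch.uint8, device=packed.device)
    bits = (packed.ravel().unsqueeze(1) >> shifts) & 1
    return (bits.to(torch.float32) * 2 - 1).reshape(size, -1).to(dtype)
```

A round trip `unpack_ref(pack_ref(t, 1), 1, t.dtype)` yields the elementwise sign of `t` as ±1, which is all the 1-bit allreduce needs; the old path through an intermediate bool tensor computed the same thing with an extra conversion.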
diff --git a/tests/onebit/test_compressed_backend.py b/tests/onebit/test_compressed_backend.py
new file mode 100644
index 000000000000..f6919a09a54b
--- /dev/null
+++ b/tests/onebit/test_compressed_backend.py
@@ -0,0 +1,96 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+import deepspeed.comm as dist
+import numpy as np
+import argparse
+import deepspeed
+import os
+
+from deepspeed.runtime.comm.compressed import CompressedBackend
+from deepspeed.accelerator import get_accelerator
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--local_rank', type=int, default=-1)
+args = parser.parse_args()
+
+deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())
+args.local_rank = int(os.environ['LOCAL_RANK'])
+
+get_accelerator().set_device(args.local_rank)
+device = torch.device(get_accelerator().device_name(), args.local_rank)
+
+size = dist.get_world_size()
+rank = dist.get_rank()
+
+backend = CompressedBackend()
+local_rank = args.local_rank
+
+
+# A simulated compression function using deepspeed.comm
+def torch_sim(a):
+    a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
+    scale = a.norm() / np.sqrt(a.numel())
+    a_compressed = scale * a_sign
+    a_sign = None
+    worker_error = a - a_compressed
+    dist.all_reduce(a_compressed)
+    a_compressed.mul_(1 / dist.get_world_size())
+    a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
+    a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
+    server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
+    a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
+    a_server_compressed = torch.cat([server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
+    rank = dist.get_rank()
+    server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
+    get_accelerator().synchronize()
+    dist.barrier()
+    return a_server_compressed, worker_error, server_error
+
+
+tensor_size = 300 * 2**20
+server_size = int(tensor_size / size)
+if tensor_size % (8 * size) != 0:
+    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
+else:
+    right_tensor_size = tensor_size
+right_server_size = right_tensor_size // size
+
+# Adding bias to the initialization of the gradient we are communicating
+# In order to get rid of the case where some elements in the gradient are too small
+a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
+
+worker_error = torch.zeros(right_tensor_size, device=device)
+server_error = torch.zeros(right_server_size, device=device)
+
+a_torch, worker_error_torch, server_error_torch = torch_sim(a)
+get_accelerator().empty_cache()
+
+a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)
+
+print(a_torch.cpu())
+print(a_after.cpu())
+
+threshold = 1e-6
+magnitude_threshold = 1e-6
+diff_mask = (a_after - a_torch) > threshold
+diff_server_mask = torch.chunk(diff_mask, size)[rank]
+mpi_server = torch.chunk(a_after, size)[rank] + server_error
+torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch
+
+test_correctness = True
+
+# If the number in the compensated_server_m is too small (e.g. 1e-8), then calling sign() might be problematic
+# The test would skip those numbers that are too small in compensated_server_m
+if test_correctness:
+    if torch.sum(diff_server_mask) == 0:
+        print('Successfully passed the test for Compressed Backend at Rank {}'.format(rank))
+    else:
+        check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold
+        if torch.sum(check_mag_mask) == 0:
+            print('Successfully passed the test for Compressed Backend at Rank {}'.format(rank))
+        else:
+            print('Fails at {} positions'.format(torch.sum(check_mag_mask)))
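`torch_sim` mirrors the two-stage compressed allreduce: each rank 1-bit-compresses its tensor, the results are averaged, and each server chunk is compressed again, with the quantization residuals (`worker_error`, `server_error`) retained for error feedback. The single compression step it repeats looks like this (an illustrative sketch, not DeepSpeed API; the chained `sign().add_(1).bool().float().add_(-0.5).mul_(2.0)` in the test is just an in-place way to get ±1 with zeros counted as positive):

```python
import torch

def one_bit_compress(a: torch.Tensor):
    # ±1 sign with zeros mapped to +1, matching the chained in-place ops above
    sign = torch.where(a >= 0, torch.ones_like(a), -torch.ones_like(a))
    # scale chosen so that ||scale * sign||_2 == ||a||_2
    scale = a.norm() / a.numel() ** 0.5
    compressed = scale * sign
    error = a - compressed  # residual carried into the next iteration (error feedback)
    return compressed, error
```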
diff --git a/tests/onebit/test_compressed_perf.py b/tests/onebit/test_compressed_perf.py
new file mode 100644
index 000000000000..a686af0f6b8d
--- /dev/null
+++ b/tests/onebit/test_compressed_perf.py
@@ -0,0 +1,97 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+import deepspeed.comm as dist
+import numpy as np
+import argparse
+import deepspeed
+import os
+
+from deepspeed.runtime.comm.compressed import CompressedBackend
+from deepspeed.utils.timer import SynchronizedWallClockTimer
+from deepspeed.accelerator import get_accelerator
+from statistics import mean
+
+timers = SynchronizedWallClockTimer()
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--local_rank', type=int, default=-1)
+args = parser.parse_args()
+
+deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())
+args.local_rank = int(os.environ['LOCAL_RANK'])
+
+get_accelerator().set_device(args.local_rank)
+device = torch.device(get_accelerator().device_name(), args.local_rank)
+
+size = dist.get_world_size()
+rank = dist.get_rank()
+
+backend = CompressedBackend()
+local_rank = args.local_rank
+
+# Setting tensor_size (BERT-Large)
+tensor_size = 300 * 2**20
+server_size = int(tensor_size / size)
+if tensor_size % (8 * size) != 0:
+    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
+else:
+    right_tensor_size = tensor_size
+right_server_size = right_tensor_size // size
+
+# Adding bias to the initialization of the gradient we are communicating
+# In order to get rid of the case where some elements in the gradient are too small
+a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank
+
+worker_error = torch.zeros(right_tensor_size, device=device)
+server_error = torch.zeros(right_server_size, device=device)
+
+warmup = 10
+iters = 10
+
+# Warmup
+for i in range(warmup):
+    backend.compressed_allreduce(a, worker_error, server_error, local_rank)
+
+time_list = []
+
+a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
+scale = a.norm() / np.sqrt(a.numel())
+a_compressed = scale * a_sign
+
+print("Shape of the compressed buffer:", a_compressed.shape) if rank == 0 else None
+
+for i in range(iters):
+    timers('compressed_allreduce').start()
+    backend.compressed_allreduce(a, worker_error, server_error, local_rank)
+    #deepspeed.comm.all_reduce(a_compressed)
+    timers('compressed_allreduce').stop()
+    time_list.append(timers('compressed_allreduce').elapsed())
+
+#timer_names = ['compressed_allreduce']
+#timers.log(names=timer_names, normalizer=1, memory_breakdown=None)
+
+places = 2
+convert = 1e3
+float_size = 4
+
+if rank == 0:
+    for i in range(iters):
+        lat = time_list[i]
+        print("latency = ", lat * convert)
+
+minlat = round(min(time_list) * convert)
+maxlat = round(max(time_list) * convert)
+meanlat = round(mean(time_list) * convert, places)
+print("min, max, and mean = {} ms, {} ms, {} ms".format(minlat, maxlat, meanlat)) if rank == 0 else None
+#print("tensor shape", a.shape)
+duration = meanlat / 1e3
+tput = ((tensor_size * 4) / duration)
+print("algo throughput: %f Bytes/s, %f GB/s" % (tput, tput / 1e9)) if rank == 0 else None
+size = tensor_size * 4
+n = dist.get_world_size()
+busbw = (size / duration) * (2 * (n - 1) / n)
+print("busbw: %f GB/s" % (busbw / 1e9)) if rank == 0 else None
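To interpret the printed numbers: the script reports algorithmic throughput as fp32 payload bytes over the mean latency, then a bus-bandwidth estimate using the standard allreduce correction factor 2(n-1)/n. A quick sanity check with made-up values (the world size and latency below are hypothetical):

```python
tensor_size = 300 * 2**20            # elements, as in the script above
payload = tensor_size * 4            # bytes, assuming fp32
n = 8                                # hypothetical world size
duration = 0.25                      # hypothetical mean latency in seconds

algbw = payload / duration           # the "algo throughput" line
busbw = algbw * (2 * (n - 1) / n)    # the "busbw" line
print("algbw %.2f GB/s, busbw %.2f GB/s" % (algbw / 1e9, busbw / 1e9))
```

Note the payload is the uncompressed gradient size, so this measures the effective bandwidth the compressed allreduce achieves relative to a plain fp32 allreduce of the same tensor.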