-
Notifications
You must be signed in to change notification settings - Fork 4.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from Liangliang-Ma/onebitadam
Onebitadam
- Loading branch information
Showing
10 changed files
with
473 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
// DeepSpeed Team | ||
|
||
#include <ipex.h> | ||
#include <torch/extension.h> | ||
#include <iostream> | ||
#include <sycl/sycl.hpp> | ||
|
||
using namespace sycl; | ||
using namespace xpu; | ||
|
||
void packbitskernel(const float* input, uint8_t* output, const int input_size, id<1> item_ct1) | ||
{ | ||
// get the sign bit of each float and pack them into byte | ||
int i = item_ct1; | ||
for (int j = 0; j < 8; ++j) { | ||
int k = i * 8 + j; | ||
int bit = k < input_size && (!sycl::signbit(input[k])); | ||
output[i] |= bit << (7 - j); | ||
} | ||
} | ||
|
||
void unpackbitskernel(const uint8_t* input, float* output, id<1> item_ct1) | ||
{ | ||
// use the bit value to set float, bit 0 -> float -1, bit 1 -> float 1 | ||
int i = item_ct1; | ||
output[i] = (float((input[i / 8] >> (7 - i % 8)) & 1) - 0.5) * 2; | ||
} | ||
|
||
sycl::queue get_current_queue(at::Device device) | ||
{ | ||
c10::impl::VirtualGuardImpl impl(device.type()); | ||
c10::Stream _stream = impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false); | ||
sycl::queue queue = xpu::get_queue_from_stream(_stream); | ||
return queue; | ||
} | ||
|
||
at::Tensor packbits(at::Tensor tensor, int input_size, int rank) | ||
{ | ||
/* | ||
pack float tensor into uint8 tensor. Every eight float elements get packed into one uint8 | ||
if float x >= 0, will be packed as a '1' bit, or will be packed as '0' | ||
Arguments: | ||
tensor: A bool tensor that get packed. | ||
input_size: numel of input tensor | ||
rank: device id in order to get corresponding stream | ||
*/ | ||
at::Device device = "xpu:" + std::to_string(rank); | ||
sycl::queue q = get_current_queue(device); | ||
|
||
int packed_size = (input_size + 7) / 8; | ||
auto unit8_options = at::TensorOptions().dtype(at::kByte).device(at::kXPU); | ||
at::Tensor packed = torch::zeros({packed_size}, unit8_options); | ||
|
||
float* input = (float*)tensor.data_ptr(); | ||
uint8_t* output = (uint8_t*)packed.data_ptr(); | ||
|
||
auto event = q.submit([&](sycl::handler& cgh) { | ||
cgh.parallel_for<>(range(packed_size), [=](id<1> item_ct1) { | ||
packbitskernel(input, output, input_size, item_ct1); | ||
}); | ||
}); | ||
|
||
return packed; | ||
} | ||
|
||
at::Tensor unpackbits(at::Tensor tensor, int input_size, int rank) | ||
{ | ||
/* | ||
unpack uint8 tensor into float tensor. Every uint8 element get unpacked into eight float | ||
a '1' bit will be converted to a float(1), a '0' bit will be converted to a float(-1). | ||
Arguments: | ||
tensor: A uint8 tensor that get unpacked. | ||
input_size: numel of input tensor | ||
rank: device id in order to get corresponding stream | ||
*/ | ||
at::Device device = "xpu:" + std::to_string(rank); | ||
sycl::queue q = get_current_queue(device); | ||
|
||
auto float_options = at::TensorOptions().dtype(at::kFloat).device(at::kXPU); | ||
at::Tensor unpacked = torch::empty({input_size * 8}, float_options); | ||
|
||
uint8_t* input = (uint8_t*)tensor.data_ptr(); | ||
float* output = (float*)unpacked.data_ptr(); | ||
|
||
auto event = q.submit([&](sycl::handler& cgh) { | ||
cgh.parallel_for<>(range(input_size * 8), | ||
[=](id<1> item_ct1) { unpackbitskernel(input, output, item_ct1); }); | ||
}); | ||
|
||
return unpacked; | ||
} | ||
|
||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) | ||
{ | ||
m.def("packbits", &packbits, "DeepSpeed XPU packbits (C++)"); | ||
m.def("unpackbits", &unpackbits, "DeepSpeed XPU unpackbits (C++)"); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
import numpy as np | ||
import torch | ||
import deepspeed.comm as dist | ||
from deepspeed.accelerator import get_accelerator | ||
from deepspeed.ops.op_builder import PackbitsBuilder | ||
|
||
|
||
class CompressedBackend(object): | ||
|
||
def __init__(self, mpu=None): | ||
if mpu is None: | ||
self.world_group = dist.new_group(ranks=range(dist.get_world_size())) | ||
else: | ||
self.mpu = mpu | ||
self.world_group = self.mpu.get_data_parallel_group() | ||
self.size = dist.get_world_size(group=self.world_group) | ||
self.rank = dist.get_rank(group=self.world_group) | ||
self.packer = PackbitsBuilder().load() | ||
|
||
def my_igather(self, rank, size, group, sendbuf, recvbuf, root): | ||
req = [] | ||
if rank == root: | ||
for idx in range(size): | ||
if idx != rank: | ||
req.append(dist.irecv(recvbuf[idx], src=idx, group=group)) | ||
else: | ||
recvbuf[rank] = sendbuf | ||
else: | ||
req.append(dist.isend(sendbuf, group=group, dst=root)) | ||
return req | ||
|
||
def my_gather(self, rank, size, group, sendbuf, recvbuf, root): | ||
if rank == root: | ||
for idx in range(size): | ||
if idx != rank: | ||
dist.recv(recvbuf[idx], src=idx, group=group) | ||
else: | ||
recvbuf[rank] = sendbuf | ||
else: | ||
dist.send(sendbuf, group=group, dst=root) | ||
|
||
def pack(self, buffer, size): | ||
# pack float tensor into uint8 tensor | ||
packed = self.packer.packbits(buffer.float(), buffer.numel(), self.rank) | ||
return packed.reshape(size, -1) | ||
|
||
def unpack(self, buffer, size, dtype): | ||
# unpack uint8 to float tensor | ||
unpacked = self.packer.unpackbits(buffer, buffer.numel(), self.rank) | ||
return unpacked.reshape(size, -1).to(dtype) | ||
|
||
def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_error, local_rank): | ||
original_shape = buffer_m.size() | ||
if len(original_shape) > 1: | ||
buffer_m = torch.flatten(buffer_m) | ||
|
||
# align size of original_buffer and error | ||
original_size = buffer_m.numel() | ||
worker_error_size = worker_error.numel() | ||
if original_size != worker_error_size: | ||
empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device) | ||
buffer_m = torch.cat([buffer_m, empty_tensor]) | ||
|
||
buffer_m.add_(worker_error) | ||
worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(torch.numel(buffer_m)) | ||
|
||
worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)) | ||
|
||
sign_list_packed_tmp = self.pack(buffer_m, self.size).type(torch.int8) | ||
|
||
recvbuf_sign = torch.zeros([self.size, len(sign_list_packed_tmp[self.rank])], | ||
dtype=sign_list_packed_tmp[0].dtype, | ||
device=sign_list_packed_tmp.device) | ||
|
||
sign_list_packed = [sign_list_packed_tmp[idx] for idx in range(self.size)] | ||
|
||
recvbuf_scale = [ | ||
torch.zeros(1, dtype=worker_scale.dtype, device=get_accelerator().current_device_name()) | ||
for _ in range(self.size) | ||
] | ||
|
||
# communication phase 1 | ||
# all to all for sign | ||
dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group) | ||
# all gather for scale | ||
dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group) | ||
|
||
flattened_recvbuf_sign = recvbuf_sign.type(torch.uint8).flatten() | ||
compensated_server_m = self.unpack(flattened_recvbuf_sign, self.size, torch.float32) \ | ||
.mul_(torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0) | ||
|
||
compensated_server_m.add_(server_error) | ||
|
||
server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel()) | ||
|
||
server_error.set_(compensated_server_m - | ||
server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)) | ||
|
||
server_sign_packed = self.pack(compensated_server_m, 1).type(torch.int8) | ||
|
||
# recvbuf_sign_server | ||
recvbuf_sign_server_tmp = torch.zeros([self.size, len(server_sign_packed[0])], | ||
dtype=recvbuf_sign.dtype, | ||
device=server_sign_packed.device) | ||
|
||
recvbuf_sign_server = [recvbuf_sign_server_tmp[idx] for idx in range(self.size)] | ||
|
||
# recvbuf_scale_server | ||
recvbuf_scale_server_tmp = torch.zeros([self.size, 1], | ||
dtype=worker_scale.dtype, | ||
device=server_sign_packed.device) | ||
|
||
recvbuf_scale_server = [recvbuf_scale_server_tmp[idx] for idx in range(self.size)] | ||
|
||
# communication Phase 2 | ||
dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group) | ||
dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group) | ||
|
||
recvbuf_sign_server = torch.stack(recvbuf_sign_server) | ||
|
||
flattened_recvbuf_sign_server = recvbuf_sign_server.type(torch.uint8).flatten() | ||
|
||
buffer_m.data.copy_( | ||
self.unpack(flattened_recvbuf_sign_server, self.size, | ||
torch.float32).mul_(recvbuf_scale_server_tmp).flatten().data) | ||
|
||
if original_size != worker_error_size: | ||
buffer_m = buffer_m[0:original_size] | ||
if len(original_shape) > 1: | ||
buffer_m = buffer_m.reshape(original_shape) | ||
|
||
return buffer_m |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
from .builder import SYCLOpBuilder | ||
|
||
|
||
class PackbitsBuilder(SYCLOpBuilder): | ||
BUILD_VAR = "DS_BUILD_PACK_BITS" | ||
NAME = "pack_bits" | ||
|
||
def __init__(self): | ||
super().__init__(name=self.NAME) | ||
|
||
def absolute_name(self): | ||
return f'deepspeed.ops.{self.NAME}_op' | ||
|
||
def sources(self): | ||
return ['csrc/xpu/packbits/packing.cpp'] | ||
|
||
def include_paths(self): | ||
return ['csrc/xpu/includes'] | ||
|
||
def cxx_args(self): | ||
args = super().cxx_args() | ||
return args + self.version_dependent_macros() |
Oops, something went wrong.