Skip to content

Commit

Permalink
change kernel and add ut
Browse files Browse the repository at this point in the history
  • Loading branch information
Liangliang-Ma committed Apr 28, 2024
1 parent c883555 commit d0c9b4b
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 9 deletions.
17 changes: 11 additions & 6 deletions csrc/xpu/packbits/packing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,29 @@

#include <ipex.h>
#include <torch/extension.h>
#include <cmath>
#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace xpu;

void packbitskernel(const bool* input, uint8_t* output, const int input_size, id<1> item_ct1)
void packbitskernel(const float* input, uint8_t* output, const int input_size, id<1> item_ct1)
{
// get the sign bit of each float and pack them into byte
int i = item_ct1;
for (int j = 0; j < 8; ++j) {
int k = i * 8 + j;
int bit = k < input_size && input[k] != 0;
int bit = k < input_size && (!sycl::signbit(input[k]));
output[i] |= bit << (7 - j);
}
}

void unpackbitskernel(const uint8_t* input, float* output, id<1> item_ct1)
{
    // Expand one packed bit back into a float: bit 1 -> +1.0f, bit 0 -> -1.0f.
    // Bits were packed MSB first, so element idx lives at bit (7 - idx % 8)
    // of byte idx / 8.
    const int idx = item_ct1;
    const int bit = (input[idx / 8] >> (7 - idx % 8)) & 1;
    output[idx] = (float(bit) - 0.5) * 2;
}

sycl::queue get_current_queue(at::Device device)
Expand All @@ -38,7 +41,8 @@ sycl::queue get_current_queue(at::Device device)
at::Tensor packbits(at::Tensor tensor, int input_size, int rank)
{
/*
pack bool tensor into uint8 tensor. Every eight bool elements get packed into one uint8
pack float tensor into uint8 tensor. Every eight float elements get packed into one uint8
if float x >= 0, will be packed as a '1' bit, or will be packed as '0'
Arguments:
tensor: A float tensor to be packed.
input_size: numel of input tensor
Expand All @@ -49,9 +53,9 @@ at::Tensor packbits(at::Tensor tensor, int input_size, int rank)

int packed_size = (input_size + 7) / 8;
auto unit8_options = at::TensorOptions().dtype(at::kByte).device(at::kXPU);
at::Tensor packed = torch::empty({packed_size}, unit8_options);
at::Tensor packed = torch::zeros({packed_size}, unit8_options);

bool* input = (bool*)tensor.data_ptr();
float* input = (float*)tensor.data_ptr();
uint8_t* output = (uint8_t*)packed.data_ptr();

auto event = q.submit([&](sycl::handler& cgh) {
Expand All @@ -67,6 +71,7 @@ at::Tensor unpackbits(at::Tensor tensor, int input_size, int rank)
{
/*
unpack uint8 tensor into float tensor. Every uint8 element get unpacked into eight float
a '1' bit will be converted to a float(1), a '0' bit will be converted to a float(-1).
Arguments:
tensor: A uint8 tensor that gets unpacked.
input_size: numel of input tensor
Expand Down
6 changes: 3 additions & 3 deletions deepspeed/runtime/comm/compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import numpy as np
import torch
# import torch_npu
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import PackbitsBuilder
Expand Down Expand Up @@ -46,11 +45,12 @@ def my_gather(self, rank, size, group, sendbuf, recvbuf, root):
dist.send(sendbuf, group=group, dst=root)

def pack(self, buffer, size):
    # Pack a float tensor into a uint8 tensor: the device packbits kernel
    # turns each run of 8 sign bits into one byte. Cast to float32 first so
    # the kernel reads a contiguous float buffer.
    compressed = self.packer.packbits(buffer.float(), buffer.numel(), self.rank)
    return compressed.reshape(size, -1)

def unpack(self, buffer, size, dtype):
    # Unpack a uint8 tensor back to floats: the device unpackbits kernel
    # expands every byte into eight +/-1 float values, then the result is
    # cast to the caller's dtype.
    expanded = self.packer.unpackbits(buffer, buffer.numel(), self.rank)
    return expanded.reshape(size, -1).to(dtype)

Expand Down
96 changes: 96 additions & 0 deletions tests/onebit/test_compressed_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed.comm as dist
import numpy as np
import argparse
import deepspeed
import os

from deepspeed.runtime.comm.compressed import CompressedBackend
from deepspeed.accelerator import get_accelerator

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()

# Bring up the distributed process group for the current accelerator.
deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())
# The launcher exports LOCAL_RANK; it overrides the CLI default of -1.
args.local_rank = int(os.environ['LOCAL_RANK'])

get_accelerator().set_device(args.local_rank)
device = torch.device(get_accelerator().device_name(), args.local_rank)

size = dist.get_world_size()
rank = dist.get_rank()

# Backend under test: sign-compressed (1-bit) allreduce.
backend = CompressedBackend()
local_rank = args.local_rank


# A simulated compression function using deepspeed.comm
def torch_sim(a):
    """Reference implementation of the 1-bit compressed allreduce.

    Compresses ``a`` to scaled sign bits, allreduces with a regular
    (uncompressed) collective, re-compresses each rank's server shard,
    and returns (server-compressed result, worker error, server error)
    so the real backend's output can be compared against it.
    """
    world_size = dist.get_world_size()
    # Worker-side compression: keep only the sign (+1/-1), rescaled so the
    # compressed tensor preserves the original L2 norm.
    sign_vec = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    norm_scale = a.norm() / np.sqrt(a.numel())
    compressed = norm_scale * sign_vec
    sign_vec = None  # release the large intermediate early
    local_error = a - compressed
    dist.all_reduce(compressed)
    compressed.mul_(1 / world_size)
    # Server-side re-compression of the averaged tensor, one shard per rank.
    server_sign = compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    chunks = torch.chunk(compressed, chunks=world_size)
    chunk_scales = [c.norm() / np.sqrt(c.numel()) for c in chunks]
    sign_chunks = torch.chunk(server_sign, world_size)
    recompressed = torch.cat([chunk_scales[i] * sign_chunks[i] for i in range(world_size)])
    my_rank = dist.get_rank()
    srv_error = chunks[my_rank] - chunk_scales[my_rank] * sign_chunks[my_rank]
    get_accelerator().synchronize()
    dist.barrier()
    return recompressed, local_error, srv_error


# Total number of elements to allreduce.
tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
# Pad so each rank's shard packs into whole bytes
# (8 sign bits per uint8, one shard per rank).
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size

# Adding bias to the initialization of the gradient we are communicating
# In order to get rid of the case where some elements in the gradient are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

# Error-feedback buffers carried across compressed_allreduce calls.
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)

# Reference result computed with regular (uncompressed) collectives.
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
get_accelerator().empty_cache()

# Result produced by the backend under test.
a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)

print(a_torch.cpu())
print(a_after.cpu())

threshold = 1e-6
magnitude_threshold = 1e-6
# NOTE(review): one-sided comparison — only positions where a_after exceeds
# a_torch by more than threshold are flagged; confirm abs() is not intended.
diff_mask = (a_after - a_torch) > threshold
diff_server_mask = torch.chunk(diff_mask, size)[rank]
# Add the error feedback back to inspect this rank's uncompensated shard.
mpi_server = torch.chunk(a_after, size)[rank] + server_error
torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch

test_correctness = True

# If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic
# The test would skip those numbers that are too small in compensated_server_m
if test_correctness:
    if torch.sum(diff_server_mask) == 0:
        print('Successfully passed the test for Compressed Backend at Rank {}'.format(rank))
    else:
        check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold
        if torch.sum(check_mag_mask) == 0:
            print('Successfully passed the test for Compressed Backend at Rank {}'.format(rank))
        else:
            print('Fails at {} of positions'.format(torch.sum(check_mag_mask)))
97 changes: 97 additions & 0 deletions tests/onebit/test_compressed_perf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed.comm as dist
import numpy as np
import argparse
import deepspeed
import os

from deepspeed.runtime.comm.compressed import CompressedBackend
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.accelerator import get_accelerator
from statistics import mean

# Shared wall-clock timer used to measure each allreduce call below.
timers = SynchronizedWallClockTimer()

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()

# Bring up the distributed process group for the current accelerator.
deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())
# The launcher exports LOCAL_RANK; it overrides the CLI default of -1.
args.local_rank = int(os.environ['LOCAL_RANK'])

get_accelerator().set_device(args.local_rank)
device = torch.device(get_accelerator().device_name(), args.local_rank)

size = dist.get_world_size()
rank = dist.get_rank()

# Backend under test: sign-compressed (1-bit) allreduce.
backend = CompressedBackend()
local_rank = args.local_rank

# Setting tensor_size (BERT-Large)
tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
# Pad so each rank's shard packs into whole bytes
# (8 sign bits per uint8, one shard per rank).
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size

# Adding bias to the initialization of the gradient we are communicating
# In order to get rid of the case where some elements in the gradient are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

# Error-feedback buffers carried across compressed_allreduce calls.
worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)

warmup = 10
iters = 10

# Warmup
for i in range(warmup):
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)

time_list = []

# Build a compressed-size reference buffer only to report its shape.
a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
a_compressed = scale * a_sign

print("Shape of the compressed buffer:", a_compressed.shape) if rank == 0 else None

# Timed iterations.
for i in range(iters):
    timers('compressed_allreduce').start()
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)
    #deepspeed.comm.all_reduce(a_compressed)
    timers('compressed_allreduce').stop()
    time_list.append(timers('compressed_allreduce').elapsed())

#timer_names = ['compressed_allreduce']
#timers.log(names=timer_names, normalizer=1, memory_breakdown=None)

places = 2  # decimal places kept when rounding the mean latency
convert = 1e3  # seconds -> milliseconds
float_size = 4  # bytes per float32 element

if rank == 0:
    for i in range(iters):
        lat = time_list[i]
        print("latency = ", lat * convert)

minlat = round(min(time_list) * convert)
maxlat = round(max(time_list) * convert)
meanlat = round(mean(time_list) * convert, places)
print("min, max, and mean = {} ms, {} ms, {} ms".format(minlat, maxlat, meanlat)) if rank == 0 else None
#print("tensor shape", a.shape)
duration = meanlat / 1e3
# NOTE(review): the literal 4 below duplicates float_size — consider reusing it.
tput = ((tensor_size * 4) / duration)
print("algo throughput: %f Bytes/s, %f GB/s" % (tput, tput / 1e9)) if rank == 0 else None
# NOTE(review): this rebinds `size` (previously the world size) to a byte
# count; the world size is re-read into `n` below, but the shadowing is fragile.
size = tensor_size * 4
n = dist.get_world_size()
busbw = (size / duration) * (2 * (n - 1) / n)
print("busbw: %f GB/s" % (busbw / 1e9)) if rank == 0 else None

0 comments on commit d0c9b4b

Please sign in to comment.