Commit d0c9b4b (parent: c883555)
Showing 4 changed files with 207 additions and 9 deletions.
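Both scripts shown below initialize distributed communication through deepspeed.init_distributed and read LOCAL_RANK from the environment, so they must be started by a multi-process launcher that sets that variable, such as the deepspeed launcher. A hypothetical invocation on two GPUs (the script names are placeholders, not taken from this diff):

    deepspeed --num_gpus=2 test_compressed_backend.py
    deepspeed --num_gpus=2 test_compressed_perf.py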
@@ -0,0 +1,96 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed.comm as dist
import numpy as np
import argparse
import deepspeed
import os

from deepspeed.runtime.comm.compressed import CompressedBackend
from deepspeed.accelerator import get_accelerator

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()

deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())
args.local_rank = int(os.environ['LOCAL_RANK'])

get_accelerator().set_device(args.local_rank)
device = torch.device(get_accelerator().device_name(), args.local_rank)

size = dist.get_world_size()
rank = dist.get_rank()

backend = CompressedBackend()
local_rank = args.local_rank


# A reference implementation of the compression scheme, simulated with plain
# deepspeed.comm collectives; used as ground truth for CompressedBackend
def torch_sim(a):
    # Map every element to +1/-1 (zeros map to +1), then rescale by the
    # tensor's RMS norm so the compressed tensor keeps the same magnitude
    a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    scale = a.norm() / np.sqrt(a.numel())
    a_compressed = scale * a_sign
    a_sign = None
    # Worker error: what this round of 1-bit compression lost
    worker_error = a - a_compressed
    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / dist.get_world_size())
    # Second (server-side) compression round on the averaged tensor, done
    # per chunk so each rank owns one chunk's scale and error
    a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
    server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
    a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
    a_server_compressed = torch.cat([server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
    rank = dist.get_rank()
    server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
    get_accelerator().synchronize()
    dist.barrier()
    return a_server_compressed, worker_error, server_error


tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
# Pad the buffer to a multiple of 8 * world_size so the signs can be packed
# into bytes and the tensor splits evenly across ranks
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size

# Add a rank-dependent bias to the initialization of the gradient we are
# communicating, to rule out the case where some elements are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)

a_torch, worker_error_torch, server_error_torch = torch_sim(a)
get_accelerator().empty_cache()

a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)

print(a_torch.cpu())
print(a_after.cpu())

threshold = 1e-6
magnitude_threshold = 1e-6
diff_mask = (a_after - a_torch) > threshold
diff_server_mask = torch.chunk(diff_mask, size)[rank]
mpi_server = torch.chunk(a_after, size)[rank] + server_error
torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch

test_correctness = True

# If a value in the compensated server tensor is too small (e.g. 1e-8),
# calling sign() on it can disagree between the two implementations, so the
# test skips positions whose compensated magnitude is below magnitude_threshold
if test_correctness:
    if torch.sum(diff_server_mask) == 0:
        print('Successfully passed the test for Compressed Backend at Rank {}'.format(rank))
    else:
        check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold
        if torch.sum(check_mag_mask) == 0:
            print('Successfully passed the test for Compressed Backend at Rank {}'.format(rank))
        else:
            print('Fails at {} positions'.format(torch.sum(check_mag_mask)))
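For intuition, what torch_sim (and the backend it checks) computes per step is 1-bit compression with error feedback: keep only the sign of each element, rescale by the tensor's RMS norm, and carry the discarded residual into the next call. A minimal single-process sketch of that arithmetic, with no communication and invented names:

    import torch

    def one_bit_compress(x, error):
        # Fold in the residual left over from the previous step.
        compensated = x + error
        # Map to +/-1 (zeros become +1), the same sign trick as torch_sim.
        sign = compensated.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
        # A single shared scale: the RMS norm of the compensated tensor.
        scale = compensated.norm() / compensated.numel() ** 0.5
        compressed = scale * sign
        # New residual: whatever the 1-bit representation could not keep.
        return compressed, compensated - compressed

    x = torch.randn(1024)
    error = torch.zeros_like(x)
    avg = torch.zeros_like(x)
    for step in range(1, 51):
        compressed, error = one_bit_compress(x, error)
        avg += (compressed - avg) / step
        if step % 10 == 0:
            # Error feedback keeps the residual bounded, so the running
            # average of the compressed outputs converges toward x.
            print(step, float((avg - x).abs().mean()))

In the real CompressedBackend the sign tensor is additionally bit-packed and exchanged with collectives; this sketch only shows the compression and compensation arithmetic that the correctness test compares against.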
@@ -0,0 +1,97 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed.comm as dist
import numpy as np
import argparse
import deepspeed
import os

from deepspeed.runtime.comm.compressed import CompressedBackend
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.accelerator import get_accelerator
from statistics import mean

timers = SynchronizedWallClockTimer()

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()

deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name())
args.local_rank = int(os.environ['LOCAL_RANK'])

get_accelerator().set_device(args.local_rank)
device = torch.device(get_accelerator().device_name(), args.local_rank)

size = dist.get_world_size()
rank = dist.get_rank()

backend = CompressedBackend()
local_rank = args.local_rank

# Set tensor_size comparable to the gradient size of BERT-Large
tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
# Pad to a multiple of 8 * world_size, as in the correctness test
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size

# Add a rank-dependent bias to the initialization of the gradient we are
# communicating, to rule out the case where some elements are too small
a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

worker_error = torch.zeros(right_tensor_size, device=device)
server_error = torch.zeros(right_server_size, device=device)

warmup = 10
iters = 10

# Warmup
for i in range(warmup):
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)

time_list = []

a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
a_compressed = scale * a_sign

if rank == 0:
    print("Shape of the compressed buffer:", a_compressed.shape)

for i in range(iters):
    timers('compressed_allreduce').start()
    backend.compressed_allreduce(a, worker_error, server_error, local_rank)
    #deepspeed.comm.all_reduce(a_compressed)
    timers('compressed_allreduce').stop()
    time_list.append(timers('compressed_allreduce').elapsed())

#timer_names = ['compressed_allreduce']
#timers.log(names=timer_names, normalizer=1, memory_breakdown=None)

places = 2
convert = 1e3  # seconds -> milliseconds
float_size = 4  # bytes per fp32 element

if rank == 0:
    for i in range(iters):
        lat = time_list[i]
        print("latency = ", lat * convert)  # in ms

minlat = round(min(time_list) * convert)
maxlat = round(max(time_list) * convert)
meanlat = round(mean(time_list) * convert, places)
if rank == 0:
    print("min, max, and mean = {} ms, {} ms, {} ms".format(minlat, maxlat, meanlat))
#print("tensor shape", a.shape)
duration = meanlat / 1e3
tput = (tensor_size * float_size) / duration
if rank == 0:
    print("algo throughput: %f Bytes/s, %f GB/s" % (tput, tput / 1e9))
size = tensor_size * float_size  # note: rebinds size from world size to bytes
n = dist.get_world_size()
# Bus bandwidth: scale algorithm bandwidth by the allreduce factor 2*(n-1)/n
busbw = (size / duration) * (2 * (n - 1) / n)
if rank == 0:
    print("busbw: %f GB/s" % (busbw / 1e9))
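The two bandwidth figures at the end differ by the usual allreduce correction: algorithm throughput is simply bytes divided by time, while bus bandwidth multiplies by 2*(n-1)/n to approximate the bytes each link actually carries (the same convention as nccl-tests). A worked example with assumed numbers, not measured output:

    # Assumptions: 8 ranks, fp32 buffer of 300 * 2**20 elements, 25 ms mean latency.
    n = 8
    size_bytes = 300 * 2**20 * 4       # 1,258,291,200 bytes
    duration = 25e-3                   # seconds (assumed, not a measurement)

    algbw = size_bytes / duration      # ~50.3e9 B/s algorithm throughput
    busbw = algbw * (2 * (n - 1) / n)  # ~88.1e9 B/s bus bandwidth
    print(algbw / 1e9, busbw / 1e9)    # -> roughly 50.33 and 88.08 (GB/s)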