Commit

Merge branch 'main' into chhwang/alg3
chhwang authored Feb 20, 2024
2 parents d064259 + a3d0799 commit aba92bf
Showing 8 changed files with 142 additions and 89 deletions.
68 changes: 56 additions & 12 deletions python/mscclpp/comm.py
@@ -23,6 +23,8 @@
import mpi4py
import numpy as np

from mscclpp.utils import is_torch_tensor


class CommGroup:
def __init__(
@@ -86,15 +88,18 @@ def make_connection(
) -> dict[int, Connection]:
if type(endpoints) is Transport:
endpoints = EndpointConfig(endpoints)
if endpoints.transport == Transport.Nvls:
return self.communicator.connct_nvls_collective(all_ranks, endpoints)
elif type(endpoints) is dict:
endpoints = {k: EndpointConfig(v) if type(v) is Transport else v for k, v in endpoints.items()}
connections = {}
for rank in all_ranks:
if type(endpoints) is dict:
endpoint = endpoints[rank]
else:
endpoint = endpoints
connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
if endpoint.transport == Transport.Nvls:
return self.communicator.connct_nvls_collective(all_ranks, endpoint)
else:
connections[rank] = self.communicator.connect_on_setup(rank, 0, endpoint)
self.communicator.setup()
connections = {rank: connections[rank].get() for rank in connections}
return connections
@@ -105,8 +110,15 @@ def register_tensor_with_connections(
transport_flags = TransportFlags()
for rank in connections:
transport_flags |= connections[rank].transport()
data_ptr = tensor.data.ptr if isinstance(tensor, cp.ndarray) else tensor.ctypes.data
local_reg_memory = self.communicator.register_memory(data_ptr, tensor.size * tensor.itemsize, transport_flags)
data_ptr = (
tensor.data.ptr
if isinstance(tensor, cp.ndarray)
else tensor.data_ptr() if is_torch_tensor(tensor) else tensor.ctypes.data
)
tensor_size = (
tensor.numel() * tensor.element_size() if is_torch_tensor(tensor) else tensor.size * tensor.itemsize
)
local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
all_registered_memories = {}
all_registered_memories[self.my_rank] = local_reg_memory
future_memories = {}
@@ -133,20 +145,24 @@ def make_sm_channels(self, tensor: cp.ndarray, connections: dict[int, Connection
semaphores = self.make_semaphore(connections, SmDevice2DeviceSemaphore)
registered_memories = self.register_tensor_with_connections(tensor, connections)
channels = {}
tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
for rank in connections:
channels[rank] = SmChannel(semaphores[rank], registered_memories[rank], tensor.data.ptr)
channels[rank] = SmChannel(semaphores[rank], registered_memories[rank], tensor_data_ptr)
return channels

def make_sm_channels_with_scratch(
self, tensor: cp.ndarray, scratchTensor: cp.ndarray, connections: dict[int, Connection]
self,
tensor: cp.ndarray,
scratchTensor: cp.ndarray,
connections: dict[int, Connection],
) -> dict[int, SmChannel]:
semaphores = self.make_semaphore(connections, SmDevice2DeviceSemaphore)
registered_memories = self.register_tensor_with_connections(scratchTensor, connections)
channels = {}
tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
scratch_data_ptr = scratchTensor.data_ptr() if is_torch_tensor(scratchTensor) else scratchTensor.data.ptr
for rank in connections:
channels[rank] = SmChannel(
semaphores[rank], registered_memories[rank], tensor.data.ptr, scratchTensor.data.ptr
)
channels[rank] = SmChannel(semaphores[rank], registered_memories[rank], tensor_data_ptr, scratch_data_ptr)
return channels

def make_proxy_channels(
@@ -177,8 +193,15 @@ def make_proxy_channels_with_scratch(
transport_flags = TransportFlags()
for rank in connections:
transport_flags |= connections[rank].transport()
data_ptr = tensor.data.ptr if isinstance(tensor, cp.ndarray) else tensor.ctypes.data
local_reg_memory = self.communicator.register_memory(data_ptr, tensor.size * tensor.itemsize, transport_flags)
data_ptr = (
tensor.data.ptr
if isinstance(tensor, cp.ndarray)
else tensor.data_ptr() if is_torch_tensor(tensor) else tensor.ctypes.data
)
tensor_size = (
tensor.numel() * tensor.element_size() if is_torch_tensor(tensor) else tensor.size * tensor.itemsize
)
local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)

semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
registered_memories = self.register_tensor_with_connections(scratchTensor, connections)
@@ -197,3 +220,24 @@ def make_proxy_channels_with_scratch(
proxy_service.proxy_channel(semaphore_ids[rank]), memory_ids[rank], memory_ids[self.my_rank]
)
return channels

def register_semaphore_with_proxy(
self, proxy_service: ProxyService, connections: dict[int, Connection]
) -> dict[int, SmChannel]:
semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
semaphore_ids = {}
for rank in semaphores:
semaphore_ids[rank] = proxy_service.add_semaphore(semaphores[rank])
channels = {}
for rank in semaphores:
channels[rank] = proxy_service.proxy_channel(semaphore_ids[rank])
return channels

def register_memory_with_proxy(
self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
) -> dict[int, int]:
registered_memories = self.register_tensor_with_connections(tensor, connections)
memory_ids = {}
for rank in registered_memories:
memory_ids[rank] = proxy_service.add_memory(registered_memories[rank])
return memory_ids
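
The comm.py changes above let CommGroup helpers accept torch tensors alongside cupy and numpy buffers, and add two helpers that register semaphores and memory with a ProxyService separately. A minimal usage sketch follows, assuming a CUDA-enabled torch build, an already-constructed CommGroup named comm_group, a connections dict returned by make_connection(), and that ProxyService is importable from the top-level mscclpp package (none of these are shown in this diff):

# Illustrative sketch only; comm_group, connections, and the import path
# below are assumptions, not part of this commit.
import torch
from mscclpp import ProxyService

tensor = torch.zeros(1 << 20, dtype=torch.float16, device="cuda")

# SM channels can now be backed by a torch tensor instead of a cupy array.
sm_channels = comm_group.make_sm_channels(tensor, connections)

# New helpers: register semaphores and memory with a proxy independently.
proxy_service = ProxyService()
proxy_channels = comm_group.register_semaphore_with_proxy(proxy_service, connections)
memory_ids = comm_group.register_memory_with_proxy(proxy_service, tensor, connections)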
127 changes: 67 additions & 60 deletions python/mscclpp/utils.py
@@ -6,69 +6,57 @@
import struct
import subprocess
import tempfile
from typing import Type
from typing import Any, Type

from cuda import cuda, nvrtc, cudart
import cupy as cp
import numpy as np

try:
import torch

def _check_cuda_errors(result):
if result[0].value:
raise RuntimeError(f"CUDA error code={result[0].value}({_cuda_get_error(result[0])})")
if len(result) == 1:
return None
elif len(result) == 2:
return result[1]
else:
return result[1:]


def _cuda_get_error(error):
if isinstance(error, cuda.CUresult):
err, name = cuda.cuGetErrorName(error)
return name if err == cuda.CUresult.CUDA_SUCCESS else "<unknown>"
elif isinstance(error, cudart.cudaError_t):
return cudart.cudaGetErrorName(error)[1]
elif isinstance(error, nvrtc.nvrtcResult):
return nvrtc.nvrtcGetErrorString(error)[1]
else:
raise RuntimeError("Unknown error type: {}".format(error))
_use_torch = True
torchTensor = torch.Tensor
except ImportError:
_use_torch = False
torchTensor = Type[Any]


class Kernel:
def __init__(self, ptx: bytes, kernel_name: str, device_id: int):
self._context = _check_cuda_errors(cuda.cuCtxGetCurrent())
assert self._context is not None
self._module = _check_cuda_errors(cuda.cuModuleLoadData(ptx))
self._kernel = _check_cuda_errors(cuda.cuModuleGetFunction(self._module, kernel_name.encode()))
CU_LAUNCH_PARAM_BUFFER_POINTER = 0x01
CU_LAUNCH_PARAM_BUFFER_SIZE = 0x02
CU_LAUNCH_PARAM_END = 0x00 if not cp.cuda.runtime.is_hip else 0x03

def __init__(self, ptx: bytes, kernel_name: str):
self._module = cp.cuda.driver.moduleLoadData(ptx)
self._kernel = cp.cuda.driver.moduleGetFunction(self._module, kernel_name)

def launch_kernel(
self,
params: bytes,
nblocks: int,
nthreads: int,
shared: int,
stream: Type[cuda.CUstream] or Type[cudart.cudaStream_t],
stream: Type[cp.cuda.Stream] or Type[None],
):
buffer = (ctypes.c_byte * len(params)).from_buffer_copy(params)
buffer_size = ctypes.c_size_t(len(params))
config = np.array(
[
cuda.CU_LAUNCH_PARAM_BUFFER_POINTER,
Kernel.CU_LAUNCH_PARAM_BUFFER_POINTER,
ctypes.addressof(buffer),
cuda.CU_LAUNCH_PARAM_BUFFER_SIZE,
Kernel.CU_LAUNCH_PARAM_BUFFER_SIZE,
ctypes.addressof(buffer_size),
cuda.CU_LAUNCH_PARAM_END,
Kernel.CU_LAUNCH_PARAM_END,
],
dtype=np.uint64,
)
_check_cuda_errors(
cuda.cuLaunchKernel(self._kernel, nblocks, 1, 1, nthreads, 1, 1, shared, stream, 0, config.ctypes.data)
cuda_stream = stream.ptr if stream else 0
cp.cuda.driver.launchKernel(
self._kernel, nblocks, 1, 1, nthreads, 1, 1, shared, cuda_stream, 0, config.ctypes.data
)

def __del__(self):
cuda.cuModuleUnload(self._module)
cp.cuda.driver.moduleUnload(self._module)


class KernelBuilder:
@@ -87,35 +75,48 @@ def __init__(self, file: str, kernel_name: str, file_dir: str = None, macro_dict
self.macros = None
if file_dir:
self.macros = ["-D{}={}".format(macro, value) for macro, value in macro_dict.items()]
device_id = cp.cuda.Device().id
ptx = self._compile_cuda(os.path.join(self._current_file_dir, file), f"{kernel_name}.ptx", device_id)
self._kernel = Kernel(ptx, kernel_name, device_id)
ptx = self._compile_cuda(os.path.join(self._current_file_dir, file), f"{kernel_name}.ptx")
self._kernel = Kernel(ptx, kernel_name)
self.kernel_map[kernel_key] = self._kernel

def _compile_cuda(self, source_file, output_file, device_id, std_version="c++17"):
def _compile_cuda(self, source_file, output_file, std_version="c++17"):
mscclpp_home = os.environ.get("MSCCLPP_HOME", "/usr/local/mscclpp")
include_dir = os.path.join(mscclpp_home, "include")
major = _check_cuda_errors(
cudart.cudaDeviceGetAttribute(cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor, device_id)
)
minor = _check_cuda_errors(
cudart.cudaDeviceGetAttribute(cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor, device_id)
)
cuda_home = os.environ.get("CUDA_HOME")
nvcc = os.path.join(cuda_home, "bin/nvcc") if cuda_home else "nvcc"
command = [
nvcc,
f"-std={std_version}",
"-ptx",
"-Xcompiler",
"-Wall,-Wextra",
f"-I{include_dir}",
f"{source_file}",
f"--gpu-architecture=compute_{major}{minor}",
f"--gpu-code=sm_{major}{minor},compute_{major}{minor}",
"-o",
f"{self._tempdir.name}/{output_file}",
]
if not cp.cuda.runtime.is_hip:
compute_capability = cp.cuda.Device().compute_capability
cuda_home = os.environ.get("CUDA_HOME")
nvcc = os.path.join(cuda_home, "bin/nvcc") if cuda_home else "nvcc"
command = [
nvcc,
f"-std={std_version}",
"-ptx",
"-Xcompiler",
"-Wall,-Wextra",
f"-I{include_dir}",
f"{source_file}",
f"--gpu-architecture=compute_{compute_capability}",
f"--gpu-code=sm_{compute_capability},compute_{compute_capability}",
"-o",
f"{self._tempdir.name}/{output_file}",
]
else:
# the gcn arch name is like "gfx942:sramecc+:xnack-"
gcn_arch = (
cp.cuda.runtime.getDeviceProperties(cp.cuda.Device().id)["gcnArchName"].decode("utf-8").split(":")[0]
)
rocm_home = os.environ.get("ROCM_HOME")
hipcc = os.path.join(rocm_home, "bin/hipcc") if rocm_home else "hipcc"
command = [
hipcc,
f"-std={std_version}",
"--genco",
"-D__HIP_PLATFORM_AMD__",
f"--offload-arch={gcn_arch}",
f"-I{include_dir}",
f"{source_file}",
"-o",
f"{self._tempdir.name}/{output_file}",
]
if self.macros:
command += self.macros
try:
@@ -145,6 +146,8 @@ def pack(*args):
res += struct.pack("P", arg.ctypes.data)
elif isinstance(arg, cp.ndarray):
res += struct.pack("P", arg.data.ptr)
elif is_torch_tensor(arg):
res += struct.pack("P", arg.data_ptr())
# use int to represent bool, which can avoid CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES error
elif isinstance(arg, bool):
res += struct.pack("i", arg)
@@ -153,3 +156,7 @@
else:
raise RuntimeError(f"Unsupported type: {type(arg)}")
return res


def is_torch_tensor(tensor: Any) -> bool:
return _use_torch and isinstance(tensor, torchTensor)
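
With these utils.py changes the torch dependency is optional: is_torch_tensor() simply returns False when torch is not installed, and the pointer and size lookups in comm.py and pack() dispatch on the buffer type. A minimal sketch of that dispatch pattern, where raw_pointer and buf are illustrative names rather than library API:

# Sketch of the pointer dispatch this diff introduces; accepts numpy, cupy,
# and (optionally) torch buffers.
import cupy as cp
from mscclpp.utils import is_torch_tensor  # always False if torch is absent

def raw_pointer(buf) -> int:
    if isinstance(buf, cp.ndarray):
        return buf.data.ptr        # cupy device pointer
    if is_torch_tensor(buf):
        return buf.data_ptr()      # torch device pointer
    return buf.ctypes.data         # numpy host pointer fallback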
4 changes: 4 additions & 0 deletions python/mscclpp_benchmark/allreduce.cu
@@ -1,7 +1,11 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#if defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_fp16.h>
#else
#include <cuda_fp16.h>
#endif

#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/nvls_device.hpp>
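
This include guard pairs with the KernelBuilder change in utils.py: on ROCm builds the kernel is compiled with hipcc and -D__HIP_PLATFORM_AMD__, so the same allreduce.cu pulls in hip/hip_fp16.h instead of cuda_fp16.h. A tiny, side-effect-free check of which toolchain path would apply on the current cupy build, using only calls that appear in the diff:

# Report which compiler path KernelBuilder would take; mirrors its branch on
# cp.cuda.runtime.is_hip without compiling anything.
import cupy as cp

if cp.cuda.runtime.is_hip:
    props = cp.cuda.runtime.getDeviceProperties(cp.cuda.Device().id)
    gcn_arch = props["gcnArchName"].decode("utf-8").split(":")[0]
    print(f"ROCm: hipcc --offload-arch={gcn_arch} -D__HIP_PLATFORM_AMD__")
else:
    cc = cp.cuda.Device().compute_capability  # e.g. "80" or "90"
    print(f"CUDA: nvcc --gpu-architecture=compute_{cc} --gpu-code=sm_{cc},compute_{cc}")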
7 changes: 3 additions & 4 deletions python/mscclpp_benchmark/allreduce_bench.py
@@ -84,7 +84,7 @@ def check_correctness(memory, func, niter=100):
for p in range(niter):
memory[:] = cp.ones(memory.shape).astype(data_type) * (p * MPI.COMM_WORLD.size + MPI.COMM_WORLD.rank)
cp.cuda.runtime.deviceSynchronize()
output_memory = func(0)
output_memory = func(None)
cp.cuda.runtime.deviceSynchronize()
expected = cp.zeros_like(memory)
for i in range(MPI.COMM_WORLD.size):
@@ -110,7 +110,7 @@ def bench_time(niter: int, func):
with stream:
stream.begin_capture()
for i in range(niter):
func(stream.ptr)
func(stream)
graph = stream.end_capture()

# now run a warm up round
@@ -165,9 +165,8 @@ def run_benchmark(
memory_out = cp.zeros(nelem, dtype=data_type)
cp.cuda.runtime.deviceSynchronize()

proxy_service = None
proxy_service = ProxyService()
if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
proxy_service = ProxyService()
if memory.nbytes < 2**20:
mscclpp_algos = [MscclppAllReduce2(mscclpp_group, memory, memory_out)]
else:
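
The benchmark now hands the algorithm callables a cupy Stream object (or None for the default stream) instead of a raw stream pointer, and constructs the ProxyService unconditionally. A minimal sketch of the new calling convention, where algo stands in for one of the MscclppAllReduce* objects built in run_benchmark() (an assumption, not shown here):

# Illustrative only: `algo` is assumed to be an MscclppAllReduce* instance.
import cupy as cp

out = algo(None)  # default stream, as in check_correctness()

stream = cp.cuda.Stream(non_blocking=True)
with stream:
    stream.begin_capture()
    out = algo(stream)            # recorded into a CUDA graph, as in bench_time()
    graph = stream.end_capture()
graph.launch(stream)              # replay the captured graph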