using nccl ops from TRT-LLM namespace #3250

Open. Wants to merge 2 commits into base: main (changes from all commits shown).
16 changes: 16 additions & 0 deletions examples/distributed_inference/README.md
@@ -14,3 +14,19 @@ See the examples started with `data_parallel` for more details.
Here we use torch.distributed as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded.

torchrun --nproc_per_node=2 tensor_parallel_llama2.py

3. Tensor parallel distributed inference using the NCCL ops plugin

apt install libmpich-dev
apt install libopenmpi-dev
pip install tensorrt-llm
# then pip install the tensorrt and torch versions compatible with the installed Torch-TensorRT
mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py

4. Tensor parallel distributed llama3 inference using the NCCL ops plugin

apt install libmpich-dev
apt install libopenmpi-dev
pip install tensorrt-llm
# then pip install the tensorrt and torch versions compatible with the installed Torch-TensorRT
mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py
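Both NCCL-plugin examples above follow the same script structure. A minimal sketch is shown below, assuming the `register_nccl_ops` helper from `tensor_parallel_nccl_ops.py` in this directory; `build_sharded_model` and the input shape are placeholders, not part of this PR.

import torch
import torch_tensorrt  # noqa: F401  # registers the "torch_tensorrt" torch.compile backend
from tensor_parallel_nccl_ops import register_nccl_ops

# Register the TRT-LLM NCCL plugins/converters and set up the process group.
device_mesh, _world_size, _rank, logger = register_nccl_ops("./tensor_parallel_example")

# Placeholder for a module already sharded over `device_mesh`
# (e.g. via parallelize_module or ParallelTransformer in these examples).
model = build_sharded_model(device_mesh)  # hypothetical helper
inp = torch.randint(32000, (8, 256), device="cuda")

compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={
        "use_python_runtime": True,
        "min_block_size": 1,
        "use_aot_joint_export": False,  # set to False in both examples in this PR
    },
    dynamic=False,
)

with torch.no_grad():
    out = compiled(inp)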
4 changes: 3 additions & 1 deletion examples/distributed_inference/requirement.txt
@@ -1,3 +1,5 @@
accelerate
transformers
diffusers
diffusers
site
tensorrt-llm
21 changes: 8 additions & 13 deletions examples/distributed_inference/tensor_parallel_llama3.py
@@ -7,25 +7,20 @@
import torch
import torch_tensorrt
from llama3_model import ModelArgs, ParallelTransformer
from tensor_parallel_nccl_ops import register_nccl_ops
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
)
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

_rank = int(os.environ["RANK"])
_world_size = int(os.environ["WORLD_SIZE"])
tp_size = 2
device_mesh, _world_size, _rank, logger = register_nccl_ops("./tensor_parallel_llama3")

logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(f"./tensor_parallel_log_{_rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)

tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
_world_size % 2 == 0
), f"TP examples require even number of GPUs, but got {_world_size} gpus"

model_args = ModelArgs(
vocab_size=32000,
@@ -38,7 +33,7 @@
)

with torch.no_grad():
model = ParallelTransformer(model_args, tp_mesh)
model = ParallelTransformer(model_args, device_mesh)
torch.manual_seed(0)
inp = torch.randint(32000, (8, 256), device="cuda")
python_result = model(inp)
@@ -53,7 +48,7 @@
"use_python_runtime": True,
"workspace_size": 1 << 33,
"debug": False,
"timing_cache_path": "/opt/file/cache/timing_cache_llama.bin",
"use_aot_joint_export": False,
},
dynamic=False,
)
185 changes: 185 additions & 0 deletions examples/distributed_inference/tensor_parallel_nccl_ops.py
@@ -0,0 +1,185 @@
import ctypes
import logging
import os
import site
from enum import IntEnum, IntFlag, auto
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
import tensorrt_llm
import torch
import torch.distributed as dist
import torch_tensorrt
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.fx import GraphModule, Node
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
custom_fused_all_gather_op,
custom_fused_reduce_scatter_op,
)
from torch_tensorrt.dynamo.types import TRTTensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name


# class for AllReduce
class AllReduceStrategy(IntEnum):
"""Warning: actual definition is in kernels/customAllReduceKernels.h.

They must be kept in sync.
"""

NCCL = 0
ONESHOT = 1
TWOSHOT = 2
AUTO = 3


class AllReduceConfig(IntFlag):
"""Warning: actual definition is in kernels/customAllReduceKernels.h.

They must be kept in sync.
"""

USE_MEMCPY = auto()
PUSH_MODE = auto()


def initialize_logger(rank, logger_file_name):
logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)
return logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(rank=0, world_size=1, port=29500):
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))

# Set up environment variable to run with mpirun
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)

# Necessary to assign a device to each rank.
torch.cuda.set_device(local_rank)

# We use nccl backend
dist.init_process_group("nccl")

# set a manual seed for reproducibility
torch.manual_seed(1111)

return local_rank, world_size


def register_nccl_ops(logger_file_name):
# Initialization
initialize_distributed_env()
# create a device mesh based on the given world_size.
_world_size = int(os.environ["WORLD_SIZE"])
Review comment from @narendasan (Collaborator), Nov 8, 2024:
Things like this I am OK pulling in "globally", since we can assume the env variable is set and presumably this is what people are doing already.


device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()
logger = initialize_logger(_rank, logger_file_name)
device_id = (
_rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

# TensorRT NCCL plugins
# Iterate over all registered plugin creators
plugin_registry = trt.get_plugin_registry()
Review comment (Collaborator):
Is this just for debugging purposes?

Reply (Collaborator Author):
Yes, to see whether the plugins with the "tensorrt_llm" namespace have been loaded properly or not.

for plugin_creator in plugin_registry.plugin_creator_list:
logger.info(
f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
)
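A stricter alternative to the logging check above (a hedged sketch, not part of this PR) would fail fast when the required plugin creators are missing:

# Hedged sketch: raise immediately if the TRT-LLM NCCL plugin creators are not registered.
registry = trt.get_plugin_registry()
for plugin_name in ("AllGather", "ReduceScatter"):
    if registry.get_plugin_creator(plugin_name, "1", "tensorrt_llm") is None:
        raise RuntimeError(
            f"TensorRT-LLM plugin '{plugin_name}' not found in namespace 'tensorrt_llm'; "
            "check that tensorrt-llm is installed and its plugin library is loaded"
        )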

@dynamo_tensorrt_converter(custom_fused_all_gather_op)
Review comment (Collaborator):
We want to start thinking about how these might get added as actual converters, like how we support quantization. I think the global variable dependency is an issue. How might we work around that?

Reply (Collaborator Author):
Yes, pulling in the global variable (assuming the environment variable is set during initialization) can be done, instead of using the dist package.

def insert_nccl_gather_op(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
plug_inputs = [args[0]]
allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
"AllGather", "1", "tensorrt_llm"
)
assert allgather_plg_creator is not None
_world_size = int(os.environ["WORLD_SIZE"])
Review comment (Collaborator):
Same here.

group = list(range(_world_size))
group = trt.PluginField(
"group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
)
p_dtype = trt.float32
pf_type = trt.PluginField(
"type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
)
pfc = trt.PluginFieldCollection([group, pf_type])
allgather = allgather_plg_creator.create_plugin("allgather", pfc)
layer = ctx.net.add_plugin_v2(plug_inputs, allgather)
set_layer_name(layer, target, name)
return layer.get_output(0)
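Regarding the global-variable concern raised in the review thread above, one possible direction (a hedged sketch, not what this PR implements) is to capture the group in a closure at registration time rather than reading os.environ inside the converter:

# Hedged sketch: register the gather converter through a factory so world_size
# is fixed at registration time instead of read from the environment per call.
def register_gather_converter(world_size: int) -> None:
    @dynamo_tensorrt_converter(custom_fused_all_gather_op)
    def insert_nccl_gather_op(ctx, target, args, kwargs, name):
        creator = trt.get_plugin_registry().get_plugin_creator(
            "AllGather", "1", "tensorrt_llm"
        )
        assert creator is not None
        group = trt.PluginField(
            "group",
            np.array(list(range(world_size)), dtype=np.int32),
            trt.PluginFieldType.INT32,
        )
        pf_type = trt.PluginField(
            "type_id", np.array([int(trt.float32)], np.int32), trt.PluginFieldType.INT32
        )
        plugin = creator.create_plugin(
            "allgather", trt.PluginFieldCollection([group, pf_type])
        )
        layer = ctx.net.add_plugin_v2([args[0]], plugin)
        set_layer_name(layer, target, name)
        return layer.get_output(0)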

@dynamo_tensorrt_converter(custom_fused_reduce_scatter_op)
def insert_nccl_reduce_scatter_plugin(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
plug_inputs = [args[0]]
allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator(
"ReduceScatter", "1", "tensorrt_llm"
)

assert allreduce_plg_creator is not None

counter = 0
strategy = AllReduceStrategy.NCCL
config = AllReduceConfig(0)
_world_size = int(os.environ["WORLD_SIZE"])
Review comment (Collaborator):
Can you add some error handling for the case where this is not set?
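A minimal guard for the lookup above (a hedged sketch of one option):

# Hedged sketch: raise a clear error if WORLD_SIZE was never exported.
world_size_env = os.environ.get("WORLD_SIZE")
if world_size_env is None:
    raise RuntimeError(
        "WORLD_SIZE is not set; launch via mpirun/torchrun or call "
        "initialize_distributed_env() before registering the NCCL converters"
    )
_world_size = int(world_size_env)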

group = list(range(_world_size))
group = trt.PluginField(
"group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
)

p_dtype = trt.float16
pf_dtype = trt.PluginField(
"type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
)
pfc = [group, pf_dtype]
p_strategy = trt.PluginField(
"strategy", np.array([int(strategy)], np.int8), trt.PluginFieldType.INT8
)
pfc.append(p_strategy)
p_config = trt.PluginField(
"config", np.array([int(config)], np.int8), trt.PluginFieldType.INT8
)
pfc.append(p_config)
p_counter = trt.PluginField(
"counter", np.array([counter], np.int32), trt.PluginFieldType.INT32
)
pfc.append(p_counter)

pfc = trt.PluginFieldCollection(pfc)
ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc)

layer = ctx.net.add_plugin_v2(plug_inputs, ar_plug)
set_layer_name(layer, target, name)
return layer.get_output(0)

return device_mesh, _world_size, _rank, logger
24 changes: 11 additions & 13 deletions examples/distributed_inference/tensor_parallel_simple_example.py
@@ -1,18 +1,22 @@
import os
import sys
import time

import tensorrt as trt
import tensorrt_llm
import torch
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_nccl_ops import register_nccl_ops
from torch.distributed._tensor import Shard
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)

device_mesh, _world_size, _rank, logger = register_nccl_ops(
"./tensor_parallel_simple_example"
)

"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""
@@ -36,14 +40,7 @@ def forward(self, x):
return x


# create a device mesh based on the given world_size.
_world_size = int(os.environ["WORLD_SIZE"])

device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()


print(f"Starting PyTorch TP example on rank {_rank}.")
logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
_world_size % 2 == 0
), f"TP examples require even number of GPUs, but got {_world_size} gpus"
@@ -78,6 +75,7 @@ def forward(self, x):
"enabled_precisions": {torch.float32, torch.float16},
"use_python_runtime": True,
"min_block_size": 1,
"use_aot_joint_export": False,
},
dynamic=False,
)
@@ -91,9 +89,9 @@ def forward(self, x):
output = tp_model(inp)
end = time.time()
if i == 0:
print(f"Compilation time is {end-start}")
logger.info(f"Compilation time is {end-start}")
assert (
python_result - output
).std() < 0.01, "Compilation result is not correct."
elif _rank == 0:
print(f"Inference time is {end-start}")
logger.info(f"Inference time is {end-start}")
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -43,6 +43,7 @@
USE_EXPLICIT_TYPING = False
USE_FP32_ACC = False
ENABLE_WEIGHT_STREAMING = False
USE_AOT_JOINT_EXPORT = True


def default_device() -> Device:
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -30,6 +30,7 @@
SPARSE_WEIGHTS,
TIMING_CACHE_PATH,
TRUNCATE_DOUBLE,
USE_AOT_JOINT_EXPORT,
USE_EXPLICIT_TYPING,
USE_FAST_PARTITIONER,
USE_FP32_ACC,
@@ -84,6 +85,7 @@ class CompilationSettings:
use_strong_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
enable_weight_streaming (bool): Enable weight streaming.
use_aot_joint_export (bool): Use aot_export_joint_simple; otherwise wrap the backend with AOT autograd
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -121,6 +123,7 @@ class CompilationSettings:
use_explicit_typing: bool = USE_EXPLICIT_TYPING
use_fp32_acc: bool = USE_FP32_ACC
enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT


_SETTINGS_TO_BE_ENGINE_INVARIANT = (
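For reference, the new flag threads from the torch.compile options dict (as in the examples above) into CompilationSettings; a minimal sketch of setting it directly, assuming the import path shown in this diff:

from torch_tensorrt.dynamo._settings import CompilationSettings

# Hedged sketch: the new field defaults to USE_AOT_JOINT_EXPORT (True) and can be
# overridden here or via the torch.compile options dict used in the examples above.
settings = CompilationSettings(use_aot_joint_export=False)
print(settings.use_aot_joint_export)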