From 846fdd2764e63eb10d2c40dfa9af4285d5f3840b Mon Sep 17 00:00:00 2001
From: Adrian Wang <123616592+cehongwang@users.noreply.github.com>
Date: Fri, 23 Aug 2024 11:26:11 -0700
Subject: [PATCH] Added tensor_parallelism examples (#3047)

Co-authored-by: cehongwang
---
 examples/distributed_inference/README.md     |   4 +-
 .../distributed_inference/llama3_model.py    | 538 ++++++++++++++++++
 .../tensor_parallel_llama3.py                |  72 +++
 .../tensor_parallel_simple_example.py        |  99 ++++
 4 files changed, 712 insertions(+), 1 deletion(-)
 create mode 100644 examples/distributed_inference/llama3_model.py
 create mode 100644 examples/distributed_inference/tensor_parallel_llama3.py
 create mode 100755 examples/distributed_inference/tensor_parallel_simple_example.py

diff --git a/examples/distributed_inference/README.md b/examples/distributed_inference/README.md
index c164e9581f..4c88570b6f 100644
--- a/examples/distributed_inference/README.md
+++ b/examples/distributed_inference/README.md
@@ -11,4 +11,6 @@ See the examples started with `data_parallel` for more details.
 
 2. Tensor parallel distributed inference
 
-In development.
+Here we use torch.distributed as an example, but compilation with tensor parallelism is agnostic to the sharding framework as long as the module is properly sharded.
+
+`torchrun --nproc_per_node=2 tensor_parallel_llama3.py`
diff --git a/examples/distributed_inference/llama3_model.py b/examples/distributed_inference/llama3_model.py
new file mode 100644
index 0000000000..9fa59b5c49
--- /dev/null
+++ b/examples/distributed_inference/llama3_model.py
@@ -0,0 +1,538 @@
+# Taken and modified from the PyTorch Lightning tensor-parallelism tutorial:
+# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
+
+
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+    parallelize_module,
+)
+
+
+@dataclass
+class ModelArgs:
+    dim: int = 4096
+    n_layers: int = 32
+    n_heads: int = 32
+    n_kv_heads: Optional[int] = None
+    vocab_size: int = -1  # defined later by tokenizer
+    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
+    ffn_dim_multiplier: Optional[float] = None
+    norm_eps: float = 1e-5
+    rope_theta: float = 10000
+
+    max_batch_size: int = 32
+    max_seq_len: int = 2048
+    # If `True`, then each transformer block init uses its layer ID, and if
+    # `False`, each uses the total number of transformer blocks
+    depth_init: bool = True
+    device: str = "cuda"
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
+    """Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
+    and the end index 'end'. The 'theta' parameter scales the frequencies.
+    The returned tensor contains complex values in complex64 data type.
+
+    Args:
+        dim (int): Dimension of the frequency tensor.
+        end (int): End index for precomputing frequencies.
+        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+
+    Returns:
+        torch.Tensor: Precomputed frequency tensor with complex exponentials.
+ + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + return torch.polar(torch.ones_like(freqs), freqs) # complex64 + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + + """ + ndim = x.ndim + assert 0 <= 1 < ndim + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class RMSNorm(nn.Module): + """Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. 
+ + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) # type: ignore + + +class Attention(nn.Module): + """Multi-head attention module. + + Args: + model_args (ModelArgs): Model configuration arguments. + + Attributes: + n_kv_heads (int): Number of key and value heads. + n_heads (int): Number of query heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. + + """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.dim // model_args.n_heads + + self.wq = nn.Linear( + model_args.dim, model_args.n_heads * self.head_dim, bias=False + ) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear( + model_args.n_heads * self.head_dim, model_args.dim, bias=False + ) + + def init_weights(self, init_std: float) -> None: + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ) -> Any: + """Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + + Returns: + torch.Tensor: Output tensor after attention. + + """ + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + xq = xq.view(bs, seqlen, self.n_heads, self.head_dim) + xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim) + xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + # we use casual mask for training + output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + """FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None. 
+ + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. + + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x) -> Any: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float) -> None: + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class TransformerBlock(nn.Module): + """TransformerBlock Module. + + Args: + layer_id (int): Identifier for the layer. + model_args (ModelArgs): Model configuration arguments. + + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. + ffn_norm (RMSNorm): Layer normalization for feedforward output. + + """ + + def __init__(self, layer_id: int, model_args: ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + self.attention = Attention(model_args) + self.feed_forward = FeedForward( + dim=model_args.dim, + hidden_dim=4 * model_args.dim, + multiple_of=model_args.multiple_of, + ffn_dim_multiplier=model_args.ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.num_layers = model_args.n_layers + + self.attention_norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps) + + if model_args.depth_init: + self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. + + """ + h = x + self.attention(self.attention_norm(x), freqs_cis) + return h + self.feed_forward(self.ffn_norm(h)) + + def init_weights(self): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + +class ParallelTransformer(nn.Module): + """Transformer Module. + + Args: + model_args (ModelArgs): Model configuration arguments. + + Attributes: + model_args (ModelArgs): Model configuration arguments. + vocab_size (int): Vocabulary size. + n_layers (int): Number of layers in the model. + tok_embeddings (ParallelEmbedding): Token embeddings. + layers (torch.nn.ModuleList): List of Transformer blocks. + norm (RMSNorm): Layer normalization for the model output. 
+ output (ColumnParallelLinear): Linear layer for final output. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + """ + + def __init__(self, model_args: ModelArgs, tp_mesh: DeviceMesh = None): + # Here we use distributed model initialization to avoid memory overflow + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + self.tok_embeddings.to(model_args.device) + self.tok_embeddings = self.parallel_embeddings(self.tok_embeddings, tp_mesh) + + # TODO persistent should be set to false, since this buffer can be recomputed. + # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, + # compile or pipeline-tracer will not correctly handle non-persistent buffers, + # so we need to fix that. (2) if we initialize pipeline-parallel models from + # a seed checkpoint rather than calling init_weights, we need freqs_cis to be + # initialized by the checkpoint, or we need to add a separate initializer for + # just the non-persistent buffers that is called after loading checkpoints. + self.register_buffer( + "freqs_cis", + self._precompute_freqs_cis().to(model_args.device), + persistent=True, + ) + + self.layers = torch.nn.ModuleDict().to(model_args.device) + for layer_id in range(model_args.n_layers): + block = TransformerBlock(layer_id, model_args).to(model_args.device) + self.layers[str(layer_id)] = block + self.parallel_transformer_block(self.layers[str(layer_id)], tp_mesh) + + self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps).to( + model_args.device + ) + self.norm = self.parallel_norm(self.norm, tp_mesh) + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False).to( + model_args.device + ) + self.output = self.parallel_output(self.output, tp_mesh) + self.init_weights() + + def parallel_transformer_block(self, transformer_block, tp_mesh): + if tp_mesh.size() <= 1: + return + plan = { + "attention": PrepareModuleInput( + input_layouts=(Shard(1), None), + desired_input_layouts=(Replicate(), None), + ), + "attention.wq": ColwiseParallel(), + "attention.wk": ColwiseParallel(), + "attention.wv": ColwiseParallel(), + "attention.wo": RowwiseParallel(output_layouts=Shard(1)), + "attention_norm": SequenceParallel(), + "feed_forward": PrepareModuleInput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "feed_forward.w1": ColwiseParallel(), + "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)), + "feed_forward.w3": ColwiseParallel(), + "ffn_norm": SequenceParallel(), + } + + # Adjust attention module to use the local number of heads + attn_layer = transformer_block.attention + attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size() + attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size() + + # Apply the plan for the current transformer block + parallelize_module(transformer_block, tp_mesh, plan) + + def parallel_embeddings(self, embedding, tp_mesh): + plan = { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ) + } + return parallelize_module(embedding, tp_mesh, plan) + + def parallel_output(self, output, tp_mesh): + plan = { + "output": ColwiseParallel( + input_layouts=Shard(1), + ), + } + return parallelize_module(output, tp_mesh, plan) + + def parallel_norm(self, norm, tp_mesh): + plan = { + "norm": SequenceParallel(), + } + return parallelize_module(norm, tp_mesh, plan) + + def 
reset_parameters(self): + with torch.device(self.freqs_cis.device): + self.freqs_cis = self._precompute_freqs_cis() + + def init_weights(self): + """[Note: On ``init_weights`` vs. + + ``reset_parameters``] + Modules may define ``reset_parameters`` to initialize parameter values. + ``reset_parameters`` is meant to only initialize directly owned + parameters/buffers, not those of their child modules, and it can be + used to give the initial values for these tensors. + Separately, users may want custom initialization for their modules, + different from that in ``reset_parameters``. For this, we define + ``init_weights``. We only call it in the constructor of this + ``Transformer`` root module to avoid reinitializing tensors. + + """ + with torch.device(self.freqs_cis.device): + self.freqs_cis = self._precompute_freqs_cis() + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + layer.init_weights() + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_freqs_cis(self) -> torch.Tensor: + return precompute_freqs_cis( + self.model_args.dim // self.model_args.n_heads, + # Need to compute until at least the max token limit for generation + # (use 2x max sequence length to be safe) + self.model_args.max_seq_len * 2, + self.model_args.rope_theta, + ) + + def forward(self, tokens: torch.Tensor): + """Perform a forward pass through the Transformer model. + + Args: + tokens (torch.Tensor): Input token indices. + + Returns: + torch.Tensor: Output logits after applying the Transformer model. + + """ + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + + h = self.norm(h) if self.norm else h + return self.output(h).float() if self.output else h + + @classmethod + def from_model_args(cls, model_args: ModelArgs) -> "Transformer": + """Initialize a Transformer model from a ModelArgs object. + + Args: + model_args (ModelArgs): Model configuration arguments. + + Returns: + Transformer: Transformer model. 
+ + """ + return cls(model_args) diff --git a/examples/distributed_inference/tensor_parallel_llama3.py b/examples/distributed_inference/tensor_parallel_llama3.py new file mode 100644 index 0000000000..fc03a64386 --- /dev/null +++ b/examples/distributed_inference/tensor_parallel_llama3.py @@ -0,0 +1,72 @@ +# Taken and modified pytorch lightening +# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning +import logging +import os +import time + +import torch +import torch_tensorrt +from llama3_model import ModelArgs, ParallelTransformer +from torch.distributed._composable.fsdp import MixedPrecisionPolicy +from torch.distributed._composable.fsdp.fully_shard import fully_shard +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, +) +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh + +_rank = int(os.environ["RANK"]) +_world_size = int(os.environ["WORLD_SIZE"]) +tp_size = 2 + +logger = logging.getLogger() +logger.setLevel(logging.INFO) +fh = logging.FileHandler(f"./tensor_parallel_log_{_rank}.log", mode="w") +fh.setLevel(logging.INFO) +logger.addHandler(fh) + +tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,)) + +model_args = ModelArgs( + vocab_size=32000, + dim=1024, + n_layers=4, + n_heads=8, + rope_theta=500000.0, + n_kv_heads=8, + device="cuda", +) + +with torch.no_grad(): + model = ParallelTransformer(model_args, tp_mesh) + torch.manual_seed(0) + inp = torch.randint(32000, (8, 256), device="cuda") + python_result = model(inp) + torch_tensorrt.runtime.set_multi_device_safe_mode(True) + model = torch.compile( + model, + fullgraph=True, + backend="torch_tensorrt", + options={ + "truncate_long_and_double": True, + "enabled_precisions": {torch.float32, torch.float16}, + "use_python_runtime": True, + "workspace_size": 1 << 33, + "debug": False, + "timing_cache_path": "/opt/file/cache/timing_cache_llama.bin", + }, + dynamic=False, + ) + for i in range(15): + # seeding with dp_rank to ensure identical inputs for TP groups + torch.manual_seed(i) + start = time.time() + output = model(inp) + end = time.time() + if i == 0: + logger.info(f"Compilation time is {end-start}") + assert ( + python_result - output + ).std() < 0.01, "Compilation result is not correct." 
+        elif _rank == 0:
+            logger.info(f"Inference time is {end-start}")
diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py
new file mode 100755
index 0000000000..470487a751
--- /dev/null
+++ b/examples/distributed_inference/tensor_parallel_simple_example.py
@@ -0,0 +1,99 @@
+import os
+import sys
+import time
+
+import torch
+import torch.nn as nn
+import torch_tensorrt
+from torch.distributed._tensor import Shard
+from torch.distributed._tensor.device_mesh import init_device_mesh
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+    parallelize_module,
+)
+
+"""
+This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
+"""
+
+
+class ToyModel(nn.Module):
+    """MLP-based model"""
+
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.in_proj = nn.Linear(10, 3200)
+        self.relu = nn.ReLU()
+        self.out_proj = nn.Linear(3200, 1600)
+        self.in_proj2 = nn.Linear(1600, 500)
+        self.out_proj2 = nn.Linear(500, 100)
+
+    def forward(self, x):
+        x = self.out_proj(self.relu(self.in_proj(x)))
+        x = self.relu(x)
+        x = self.out_proj2(self.relu(self.in_proj2(x)))
+        return x
+
+
+# create a device mesh based on the given world_size.
+_world_size = int(os.environ["WORLD_SIZE"])
+
+device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
+_rank = device_mesh.get_rank()
+
+
+print(f"Starting PyTorch TP example on rank {_rank}.")
+assert (
+    _world_size % 2 == 0
+), f"TP examples require an even number of GPUs, but got {_world_size} GPUs"
+
+
+# create model and move it to GPU - init_device_mesh has already mapped GPU ids.
+tp_model = ToyModel().to("cuda")
+
+
+# Custom parallelization plan for the model
+tp_model = parallelize_module(
+    module=tp_model,
+    device_mesh=device_mesh,
+    parallelize_plan={
+        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
+        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
+        "in_proj2": ColwiseParallel(input_layouts=Shard(0)),
+        "out_proj2": RowwiseParallel(output_layouts=Shard(0)),
+    },
+)
+torch.manual_seed(0)
+inp = torch.rand(20, 10, device="cuda")
+python_result = tp_model(inp)
+
+
+backend = "torch_tensorrt"
+tp_model = torch.compile(
+    tp_model,
+    backend=backend,
+    options={
+        "truncate_long_and_double": True,
+        "enabled_precisions": {torch.float32, torch.float16},
+        "use_python_runtime": True,
+        "min_block_size": 1,
+    },
+    dynamic=False,
+)
+
+for i in range(10):
+    # For TP, the input needs to be the same across all TP ranks.
+    # Setting the random seed here mimics the behavior of a dataloader.
+    torch.manual_seed(i)
+    inp = torch.rand(20, 10, device="cuda")
+    start = time.time()
+    output = tp_model(inp)
+    end = time.time()
+    if i == 0:
+        print(f"Compilation time is {end-start}")
+        assert (
+            python_result - output
+        ).std() < 0.01, "Compilation result is not correct."
+    elif _rank == 0:
+        print(f"Inference time is {end-start}")
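
A note for readers following the rotary-embedding helpers in llama3_model.py: the short sketch below is not part of the patch; it assumes llama3_model.py is importable, and the tensor sizes are illustrative only. It shows how precompute_freqs_cis and apply_rotary_emb fit together and checks that the rotation preserves shapes and norms.

import torch

from llama3_model import apply_rotary_emb, precompute_freqs_cis

bs, seqlen, n_heads, head_dim = 2, 16, 8, 64

# freqs_cis has shape (end, head_dim // 2) and is complex64; the model
# precomputes it up to 2 * max_seq_len and slices it per forward pass.
freqs_cis = precompute_freqs_cis(head_dim, end=seqlen * 2)

xq = torch.randn(bs, seqlen, n_heads, head_dim)
xk = torch.randn(bs, seqlen, n_heads, head_dim)

# apply_rotary_emb slices freqs_cis to the first seqlen positions internally.
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape

# Each (real, imag) pair is rotated by a unit-modulus complex factor,
# so the overall norm is unchanged up to floating-point error.
torch.testing.assert_close(xq_rot.norm(), xq.norm(), rtol=1e-4, atol=1e-4)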
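
repeat_kv is what turns the grouped key/value heads into one tensor entry per query head (grouped-query attention). A minimal shape check with illustrative sizes, again assuming llama3_model.py is importable (not part of the patch):

import torch

from llama3_model import repeat_kv

bs, seqlen, n_kv_heads, head_dim = 2, 16, 2, 64
n_heads = 8
n_rep = n_heads // n_kv_heads  # each KV head serves 4 query heads

xk = torch.randn(bs, seqlen, n_kv_heads, head_dim)
keys = repeat_kv(xk, n_rep)

# KV heads are expanded to match the number of query heads.
assert keys.shape == (bs, seqlen, n_kv_heads * n_rep, head_dim)
# Output head j is a copy of KV head j // n_rep.
assert torch.equal(keys[:, :, 3], xk[:, :, 0])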
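
RMSNorm, unlike LayerNorm, only rescales each token vector to unit root-mean-square (no mean subtraction, no bias) before applying the learnable weight. A quick sanity check, assuming llama3_model.py is importable (not part of the patch):

import torch

from llama3_model import RMSNorm

norm = RMSNorm(dim=1024)  # weight is initialized to ones
x = torch.randn(2, 16, 1024)
out = norm(x)

# With unit weights, every output vector has RMS approximately equal to 1.
rms = out.pow(2).mean(-1).sqrt()
torch.testing.assert_close(rms, torch.ones_like(rms), rtol=1e-3, atol=1e-3)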
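
The FeedForward constructor derives its hidden width from the 4 * dim value passed in by TransformerBlock: it takes two thirds of it, applies the optional ffn_dim_multiplier, and rounds up to a multiple of multiple_of. A worked example for the default ModelArgs (dim=4096, multiple_of=256, no multiplier); the numbers are illustrative and not part of the patch:

dim = 4096
multiple_of = 256

hidden_dim = 4 * dim                  # 16384, as passed by TransformerBlock
hidden_dim = int(2 * hidden_dim / 3)  # 10922
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

print(hidden_dim)  # 11008 -> width of w1/w3 in each FeedForward block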
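
In parallel_transformer_block, wq/wk/wv are ColwiseParallel and wo is RowwiseParallel, so each rank holds and produces only 1/tp_size of the attention heads; dividing n_heads and n_kv_heads by tp_mesh.size() keeps the view() calls in Attention.forward consistent with the sharded activations. A back-of-the-envelope check using the ModelArgs from tensor_parallel_llama3.py (dim=1024, n_heads=8, two ranks); illustrative arithmetic only, not part of the patch:

tp_size = 2                                # torchrun --nproc_per_node=2
dim, n_heads, n_kv_heads = 1024, 8, 8
head_dim = dim // n_heads                  # 128, unchanged by sharding

local_n_heads = n_heads // tp_size         # 4 query heads per rank
local_n_kv_heads = n_kv_heads // tp_size   # 4 KV heads per rank

# Colwise-sharded wq yields local_n_heads * head_dim = 512 features per rank,
# which is exactly what xq.view(bs, seqlen, local_n_heads, head_dim) expects.
assert local_n_heads * head_dim == (n_heads * head_dim) // tp_size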