Skip to content

Commit

Permalink
[Hardware][TPU] Support MoE with Pallas GMM kernel (#6457)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Jul 16, 2024
1 parent 9f4ccec commit c467dff
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 8 deletions.
4 changes: 3 additions & 1 deletion Dockerfile.tpu
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
ARG NIGHTLY_DATE="20240601"
ARG NIGHTLY_DATE="20240713"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"

FROM $BASE_IMAGE
WORKDIR /workspace

# Install aiohttp separately to avoid build errors.
RUN pip install aiohttp
# Install NumPy 1 instead of NumPy 2.
RUN pip install "numpy<2"
# Install the TPU and Pallas dependencies.
RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Expand Down
4 changes: 2 additions & 2 deletions docs/source/getting_started/tpu-installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ First, install the dependencies:
$ pip uninstall torch torch-xla -y
$ # Install PyTorch and PyTorch XLA.
$ export DATE="+20240601"
$ export DATE="+20240713"
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
Expand Down Expand Up @@ -85,7 +85,7 @@ Next, build vLLM from source. This will only take a few seconds:
ImportError: libopenblas.so.0: cannot open shared object file: No such file or directory
You can install OpenBLAS with the following command:
Please install OpenBLAS with the following command:

.. code-block:: console
Expand Down
18 changes: 18 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,24 @@ def forward_cpu(self, *args, **kwargs):
raise NotImplementedError(
"The CPU backend currently does not support MoE.")

def forward_tpu(
self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
num_expert_group: Optional[int],
topk_group: Optional[int],
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
return fused_moe(x, w1, w2, router_logits, top_k, renormalize)


class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.
Expand Down
62 changes: 62 additions & 0 deletions vllm/model_executor/layers/fused_moe/moe_pallas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import torch
import torch.nn.functional as F
from torch_xla.experimental.custom_kernel import _histogram


def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> torch.Tensor:
"""
Args:
hidden_states: [*, hidden_size]
w1: [num_experts, intermediate_size * 2, hidden_size]
w2: [num_experts, hidden_size, intermediate_size]
gating_output: [*, num_experts]
"""
orig_shape = hidden_states.shape
hidden_size = hidden_states.shape[-1]
num_tokens = hidden_states.shape[:-1].numel()
num_experts = w1.shape[0]
intermediate_size = w2.shape[-1]
device = hidden_states.device
dtype = hidden_states.dtype
assert (num_tokens * topk) % 16 == 0, (
"The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
f"16 but got {num_tokens * topk}")

hidden_states = hidden_states.view(num_tokens, hidden_size)
gating_output = gating_output.view(num_tokens, num_experts)
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
topk_weights, topk_indices = topk_weights.topk(topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(dtype)

topk_indices = topk_indices.flatten()
topk_argsort_indices = topk_indices.argsort()
topk_argsort_revert_indices = topk_argsort_indices.argsort()
token_indices = torch.arange(num_tokens,
device=device).repeat_interleave(topk)
token_indices = token_indices[topk_argsort_indices]
group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)

# NOTE(woosuk): The GMM Pallas kernel requires a different weight layout
# from HF Transformers.
w1 = w1.transpose(1, 2)
w2 = w2.transpose(1, 2)

x = hidden_states[token_indices]
x = torch.ops.xla.gmm(x, w1, group_sizes)
x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
x = torch.ops.xla.gmm(x, w2, group_sizes)
x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)

x = x * topk_weights.unsqueeze_(dim=-1)
x = x.sum(dim=-2)
x = x.reshape(orig_shape)
return x
9 changes: 4 additions & 5 deletions vllm/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,11 +598,10 @@ def _get_padded_prefill_len(x: int) -> int:


def _get_padded_batch_size(batch_size: int) -> int:
if batch_size <= 2:
return batch_size
elif batch_size <= 4:
return 4
elif batch_size <= 8:
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
# To meet this requirement in the simplest way, we set the minimal batch
# size to 8.
if batch_size <= 8:
return 8
else:
return ((batch_size + 15) // 16) * 16
Expand Down

0 comments on commit c467dff

Please sign in to comment.