
Commit aa58a9c

Integrate MoE kernel for torchax path (#996)
Signed-off-by: Siyuan Liu <[email protected]>
1 parent aad4c55 commit aa58a9c

File tree

5 files changed: +262 −66 lines changed


.buildkite/pipeline_jax.yml

Lines changed: 2 additions & 1 deletion
@@ -128,7 +128,8 @@ steps:
     python3 -m pytest -s -v -x /workspace/tpu_inference/tests/kernels \
       --ignore=/workspace/tpu_inference/tests/kernels/ragged_paged_attention_kernel_v2_test.py \
       --ignore=/workspace/tpu_inference/tests/kernels/ragged_kv_cache_update_v2_test.py \
-      --ignore=/workspace/tpu_inference/tests/kernels/collectives
+      --ignore=/workspace/tpu_inference/tests/kernels/collectives \
+      --ignore=/workspace/tpu_inference/tests/kernels/fused_moe_v1_test.py
   else
     echo "Skipping: no changes detected in kernels, tests/kernels, or requirements.txt"
     exit 0

tests/kernels/fused_moe_v1_test.py

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,6 @@
 import jax
 import jax.numpy as jnp
+import numpy as np
 from absl.testing import absltest
 from jax._src import test_util as jtu
 from jax.sharding import Mesh
@@ -59,7 +60,8 @@ def setUp(self):
                 (-1 if x.coords[0] % 2 else 1) * x.coords[1],
             ),
         )
-        self.mesh = Mesh(devices=self.mesh_devices, axis_names=("model", ))
+        self.mesh = Mesh(np.array(self.mesh_devices).reshape(1, -1),
+                         axis_names=("data", "model"))

     def test_basic(self):
         dtype = jnp.bfloat16
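
For context, here is a minimal standalone sketch (not part of this commit) of the mesh layout the updated test now builds: a 2D ("data", "model") mesh whose "data" axis is a singleton, matching what the fused MoE kernel below asserts on.

import jax
import numpy as np
from jax.sharding import Mesh

# All local devices go on the "model" (expert-parallel) axis; "data" stays at size 1.
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=("data", "model"))

assert mesh.shape["data"] == 1
assert mesh.shape["model"] == len(jax.devices())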

tests/layers/vllm/test_unquantized.py

Lines changed: 120 additions & 0 deletions
@@ -483,3 +483,123 @@ def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
         atol=1e-2,
         rtol=1e-2,
     )
+
+
+@pytest.mark.parametrize("use_ep", [True])
+@pytest.mark.parametrize("mesh",
+                         [test_utils.get_spmd_mesh(jax.local_device_count())])
+@pytest.mark.parametrize("num_tokens", [128, 512])
+@pytest.mark.parametrize("intermediate_size", [256, 512])
+@pytest.mark.parametrize("hidden_size", [256])
+@pytest.mark.parametrize("num_experts", [32])
+@pytest.mark.parametrize("topk", [2])
+def test_fused_moe_use_kernel(use_ep, mesh, num_tokens, intermediate_size,
+                              hidden_size, num_experts, topk):
+
+    if jax.local_device_count() < 8:
+        pytest.skip("Test requires at least 8 devices")
+
+    os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
+    torch.manual_seed(42)
+    dtype = torch.bfloat16
+
+    a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
+    w1 = torch.randn(
+        (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
+    w2 = torch.randn(
+        (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
+
+    # Use deterministic gating_output generation (same logic as fused_moe_v1_test.py).
+    # Generate base gating scores with a deterministic pattern.
+    score = (
+        torch.randn((num_tokens, num_experts), dtype=torch.float32) +
+        torch.arange(num_tokens * num_experts, dtype=torch.float32).reshape(
+            num_tokens, num_experts) / 100)
+
+    # Generate unique top-k indices.
+    generator = torch.Generator()
+    generator.manual_seed(42)
+    top_k_indices = torch.randint(0,
+                                  num_experts - 1, (num_tokens, topk),
+                                  dtype=torch.int32,
+                                  generator=generator)
+
+    # Add one-hot encoding weighted by 10 to ensure the selected experts have the highest scores.
+    one_hot = torch.nn.functional.one_hot(top_k_indices.long(),
+                                          num_classes=num_experts).float()
+    one_hot = one_hot.sum(dim=1) * 10
+
+    score = (score + one_hot).to(dtype)
+
+    torch_output = torch_moe(
+        hidden_states=a,
+        w1=w1,
+        w2=w2,
+        gating_output=score,
+        topk=topk,
+        global_num_experts=num_experts,
+        expert_map=None,
+        renormalize=False,
+    )
+
+    engine_args = EngineArgs(
+        model="Qwen/Qwen2-1.5B-Instruct",
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.model_config.dtype = dtype
+    vllm_config.parallel_config = ParallelConfig(
+        tensor_parallel_size=mesh.devices.size, enable_expert_parallel=use_ep)
+
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        vllm_fused_moe = FusedMoE(
+            num_experts=num_experts,
+            top_k=topk,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            reduce_results=True,
+            renormalize=False,
+            tp_size=mesh.devices.size,
+            dp_size=1,
+            quant_config=quant_config,
+        )
+    vllm_fused_moe.moe_parallel_config.use_ep = use_ep
+
+    vllm_fused_moe.w13_weight.data = w1
+    vllm_fused_moe.w2_weight.data = w2
+
+    p_spec = P('model', )
+    jax_a = torch_view(t2j(a, use_dlpack=False))
+    jax_a = jax_a.apply_jax_(jax.device_put, NamedSharding(mesh, p_spec))
+    score = torch_view(t2j(score))
+    score = score.apply_jax_(jax.device_put, NamedSharding(mesh, p_spec))
+
+    with torchax.default_env(), set_forward_context(None, vllm_config):
+        assert isinstance(vllm_fused_moe.quant_method,
+                          VllmUnquantizedFusedMoEMethod)
+        # Enable the kernel for this test.
+        vllm_fused_moe.quant_method.use_kernel = True
+        vllm_fused_moe.quant_method.process_weights_after_loading(
+            vllm_fused_moe)
+        vllm_fused_moe.quant_method.block_size = {
+            "bt": 32,
+            "bf": 512,
+            "bd1": 512,
+            "bd2": 512,
+            "btc": 32,
+            "bfc": 256,
+            "bd1c": 256,
+            "bd2c": 256,
+        }
+        jax_output = vllm_fused_moe(jax_a, score)
+        jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+    torch.testing.assert_close(
+        torch_output,
+        jax_output,
+        atol=1e-2,
+        rtol=1e-2,
+    )
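
The new test plants its top-k expert choices before building gating_output, so the torch_moe reference and the Pallas kernel are guaranteed to route each token to the same experts. Below is a minimal sketch of that trick with toy sizes; it uses torch.randperm instead of the test's seeded randint purely to keep the planted indices unique per token.

import torch

num_tokens, num_experts, topk = 4, 8, 2
gen = torch.Generator().manual_seed(0)

# Plant the experts each token should route to (unique within a row).
planted = torch.stack([
    torch.randperm(num_experts, generator=gen)[:topk] for _ in range(num_tokens)
])

# Base scores, then a +10 one-hot boost so the planted experts always win top-k.
score = torch.randn((num_tokens, num_experts), generator=gen)
score = score + torch.nn.functional.one_hot(
    planted, num_classes=num_experts).sum(dim=1) * 10

recovered = torch.topk(score, topk, dim=-1).indices
assert all(set(r.tolist()) == set(p.tolist()) for r, p in zip(recovered, planted))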

tpu_inference/kernels/fused_moe/v1/kernel.py

Lines changed: 26 additions & 17 deletions
@@ -7,6 +7,7 @@
 from jax import lax
 from jax._src import dtypes
 from jax.experimental import pallas as pl
+from jax.experimental import shard_map
 from jax.experimental.pallas import tpu as pltpu

 P = jax.sharding.PartitionSpec
@@ -144,7 +145,7 @@ def _fused_ep_moe_kernel(
     a2a_acc_sem,
     *,
     top_k: int,
-    ep_name: str,
+    ep_axis_name: str,
     # Kernel tuning params.
     bt: int,  # Block size of local_num_tokens.
     bf: int,  # Block size of intermediate_size.
@@ -155,8 +156,8 @@ def _fused_ep_moe_kernel(
     bd1c: int,  # Compute size of block hidden_size.
     bd2c: int,  # Compute size of block hidden_size.
 ):
-    my_id = lax.axis_index(ep_name)
-    num_devices = lax.axis_size(ep_name)
+    my_id = lax.axis_index(ep_axis_name)
+    num_devices = lax.axis_size(ep_axis_name)
     local_num_tokens = tokens_hbm.shape[0]
     local_num_experts, intermediate_size, hidden_size = w2_hbm.shape
     # num_experts = local_num_experts * num_devices
@@ -186,8 +187,8 @@ def sync_barrier():
         barrier_sem = pltpu.get_barrier_semaphore()
         pltpu.semaphore_signal(
             barrier_sem,
-            device_id=right_id,
-            device_id_type=pltpu.DeviceIdType.LOGICAL,
+            device_id=(0, right_id),
+            device_id_type=pltpu.DeviceIdType.MESH,
         )
         pltpu.semaphore_wait(barrier_sem, 1)

@@ -276,7 +277,7 @@ def _all_reduce_metadata(
             dst_ref=d2e_count_vmem.at[row_id],
             send_sem=send_sem,
             recv_sem=recv_sem,
-            device_id=(right_id, ),
+            device_id=(0, right_id),
             device_id_type=pltpu.DeviceIdType.MESH,
         ).wait()
         row_id = (row_id + num_devices - 1) % num_devices
@@ -358,7 +359,10 @@ def start_a2a_scatter(bt_id, e_sem_id, local_e_id):
                 pl.ds(start, remote_sz)],
             send_sem=send_sems.at[e_sem_id],
             recv_sem=recv_sems.at[e_sem_id],
-            device_id=(recv_id, ),
+            device_id=(
+                0,
+                recv_id,
+            ),
         ).start()
         a2a_s_sends_x2_smem[e_sem_id] = send_sz

@@ -402,7 +406,7 @@ def start_a2a_gather(bt_id, e_sem_id, local_e_id):
             dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, remote_sz)],
             send_sem=send_sems.at[e_sem_id],
             recv_sem=a2a_gather_sem,
-            device_id=(recv_id, ),
+            device_id=(0, recv_id),
         ).start()
         start += sz

@@ -831,6 +835,7 @@ def _():
         "bfc",
         "bd1c",
         "bd2c",
+        "ep_axis_name",
     ],
 )
 def fused_ep_moe(
@@ -850,12 +855,14 @@ def fused_ep_moe(
     bfc: int,
     bd1c: int,
     bd2c: int,
+    ep_axis_name: str = 'model',
 ):
-    if len(mesh.axis_names) != 1:
-        raise ValueError("Mesh must have only one axis")
+    # Assert all other axes have length of 1
+    assert len(mesh.shape) == 2, "Expect 2D mesh in tpu-inference"
+    assert 'data' in mesh.shape and mesh.shape['data'] == 1, \
+        "Expect data axis size of 1 in tpu-inference"

-    ep_name = mesh.axis_names[0]
-    ep_size = mesh.axis_sizes[0]
+    ep_size = mesh.shape[ep_axis_name]
     num_devices = ep_size

     num_tokens, actual_hidden_size = tokens.shape
@@ -907,7 +914,7 @@ def fused_ep_moe(
         functools.partial(
             _fused_ep_moe_kernel,
             top_k=top_k,
-            ep_name=ep_name,
+            ep_axis_name=ep_axis_name,
             bt=bt,
             bf=bf,
             bd1=bd1,
@@ -999,11 +1006,13 @@ def fused_ep_moe(
         ))

     @jax.jit
-    @jax.shard_map(
+    @functools.partial(
+        shard_map.shard_map,
         mesh=mesh,
-        in_specs=(P(ep_name), P(ep_name), P(ep_name), P(ep_name), P()),
-        out_specs=P(ep_name),
-        check_vma=False,
+        in_specs=(P(ep_axis_name), P(ep_axis_name), P(ep_axis_name),
+                  P(ep_axis_name), P()),
+        out_specs=P(ep_axis_name),
+        check_rep=False,
     )
     def kernel(tokens, w1, w2, gating_output, a2a_g_hbm_scratch):
         return fused_moe(
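
For orientation, here is a rough sketch (not taken from this commit) of the calling pattern the reworked entry point implies: a 2D ("data", "model") mesh with a singleton "data" axis, shard_map from jax.experimental applied via functools.partial, inputs split along ep_axis_name, and check_rep=False because the kernel issues its own remote copies. The function body, shapes, and specs here are illustrative placeholders, not the real fused_ep_moe signature.

import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(1, -1), axis_names=("data", "model"))
ep_axis_name = "model"
assert mesh.shape["data"] == 1  # mirrors the assertion added to fused_ep_moe

@jax.jit
@functools.partial(
    shard_map.shard_map,
    mesh=mesh,
    in_specs=(P(ep_axis_name), P()),  # tokens sharded on the EP axis, gating replicated
    out_specs=P(ep_axis_name),
    check_rep=False,  # the kernel manages its own cross-device transfers
)
def per_device_moe(tokens, gating_output):
    # Stand-in body: the real kernel uses lax.axis_index(ep_axis_name) /
    # lax.axis_size(ep_axis_name) here, exactly as _fused_ep_moe_kernel now does.
    return tokens

tokens = jnp.zeros((8 * jax.device_count(), 16), dtype=jnp.bfloat16)
gating = jnp.zeros((8 * jax.device_count(), 32), dtype=jnp.bfloat16)
out = per_device_moe(tokens, gating)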
