diff --git a/python/sgl_jax/srt/layers/gmm/auto_tune_tiling.py b/python/sgl_jax/srt/layers/gmm/auto_tune_tiling.py
new file mode 100644
index 00000000..f341ac05
--- /dev/null
+++ b/python/sgl_jax/srt/layers/gmm/auto_tune_tiling.py
@@ -0,0 +1,277 @@
+import functools
+import json
+import logging
+import os
+import time
+from typing import List, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from sgl_jax.srt.layers.gmm.megablox_gmm_backend import gmm
+
+logger = logging.getLogger(__name__)
+
+
+class TilingAutoTuner:
+    def __init__(self, cache_dir: str = "/tmp/tune_cache"):
+        self.cache_dir = cache_dir
+        os.makedirs(cache_dir, exist_ok=True)
+
+    def _get_cache_key(self, m: int, k: int, n: int, num_groups: int) -> str:
+        return f"m{m}_k{k}_n{n}_g{num_groups}"
+
+    def _get_cache_file(self, cache_key: str) -> str:
+        return os.path.join(self.cache_dir, f"{cache_key}.json")
+
+    def _load_cached_result(self, cache_key: str) -> Optional[Tuple[int, int, int]]:
+        cache_file = self._get_cache_file(cache_key)
+        if os.path.exists(cache_file):
+            try:
+                with open(cache_file, "r") as f:
+                    data = json.load(f)
+                return tuple(data["optimal_tiling"])
+            except Exception:
+                pass
+        return None
+
+    def _save_cached_result(
+        self, cache_key: str, optimal_tiling: Tuple[int, int, int], best_time: float
+    ):
+        cache_file = self._get_cache_file(cache_key)
+        data = {
+            "optimal_tiling": list(optimal_tiling),
+            "best_time_ms": best_time * 1000,
+            "timestamp": time.time(),
+        }
+        with open(cache_file, "w") as f:
+            json.dump(data, f, indent=2)
+
+    def _create_test_data(
+        self,
+        m: int,
+        k: int,
+        n: int,
+        num_groups: int,
+        dtype: jnp.dtype = jnp.bfloat16,
+        seed: int = 42,
+    ):
+        key = jax.random.PRNGKey(seed)
+        keys = jax.random.split(key, 2)
+
+        lhs = jax.random.normal(keys[0], (m, k), dtype=dtype)
+        rhs = jax.random.normal(keys[1], (num_groups, k, n), dtype=dtype)
+        group_sizes = jnp.array([m // num_groups] * num_groups, dtype=jnp.int32)
+
+        return lhs, rhs, group_sizes
+
+    def _benchmark_tiling(
+        self,
+        lhs,
+        rhs,
+        group_sizes,
+        tiling: Tuple[int, int, int],
+        num_warmup: int = 1,
+        num_trials: int = 3,
+    ) -> float:
+        @functools.partial(jax.jit, static_argnames=["tiling"])
+        def jitted_gmm(lhs, rhs, group_sizes, tiling):
+            return gmm(
+                lhs, rhs, group_sizes, preferred_element_type=jnp.float32, tiling=tiling
+            )
+
+        # Warmup
+        for _ in range(num_warmup):
+            out = jitted_gmm(lhs, rhs, group_sizes, tiling)
+            jax.block_until_ready(out)
+
+        times = []
+        for _ in range(num_trials):
+            start = time.perf_counter()
+            out = jitted_gmm(lhs, rhs, group_sizes, tiling)
+            jax.block_until_ready(out)
+            times.append(time.perf_counter() - start)
+
+        return np.mean(times)
+    def _generate_tiling_candidates(
+        self, m: int, k: int, n: int
+    ) -> List[Tuple[int, int, int]]:
+        tile_sizes_m = [
+            8,
+            16,
+            32,
+            64,
+            128,
+            256,
+            512,
+            1024,
+            2048,
+        ]  # m can be small for decode
+        tile_sizes_k = [128, 256, 512, 1024, 2048]  # k >= 128 for TPU
+        tile_sizes_n = [128, 256, 512, 1024, 2048]  # n >= 128 for TPU
+
+        candidates = []
+
+        for tm in tile_sizes_m:
+            if tm > m:
+                continue
+            for tk in tile_sizes_k:
+                if tk > k:
+                    continue
+                for tn in tile_sizes_n:
+                    if tn > n:
+                        continue
+
+                    # GMM constraint: dimensions must be divisible by tile sizes
+                    if m % tm != 0 or k % tk != 0 or n % tn != 0:
+                        continue
+
+                    # TPU constraints: check effective dimensions (min of tile_size and actual dimension)
+                    effective_tk = min(tk, k)
+                    effective_tn = min(tn, n)
+
+                    # TPU requires: k dimension divisible by 8, n dimension divisible by 128
+                    if effective_tk % 8 != 0 or effective_tn % 128 != 0:
+                        continue
+
+                    candidates.append((tm, tk, tn))
+
+        default_tm = 8  # Start with small value for decode compatibility
+        default_tk = 128  # Start with TPU-safe minimum
+        default_tn = 128  # Start with TPU-safe minimum
+
+        # Find the largest tm that divides m (including smaller values for decode)
+        for tm in tile_sizes_m:
+            if tm <= m and m % tm == 0:
+                default_tm = tm
+
+        # Find the largest tk that divides k and meets TPU constraints
+        for tk in reversed(tile_sizes_k):
+            if tk <= k and k % tk == 0:
+                default_tk = tk
+                break
+
+        # Find the largest tn that divides n and meets TPU constraints
+        for tn in reversed(tile_sizes_n):
+            if tn <= n and n % tn == 0:
+                default_tn = tn
+                break
+
+        default_tiling = (default_tm, default_tk, default_tn)
+        if default_tiling not in candidates and all(d > 0 for d in default_tiling):
+            candidates.append(default_tiling)
+
+        candidates.sort(key=lambda x: (x[0] * x[1] * x[2], x[0], x[1], x[2]))
+
+        return candidates
+
+    def _format_failure_summary(self, failure_reasons: dict) -> str:
+        """Format failure reasons into a readable summary."""
+        if not failure_reasons:
+            return "None"
+
+        summary_parts = []
+        for error_type, details in failure_reasons.items():
+            count = details["count"]
+            examples = details["examples"]
+            if count == 1 and examples:
+                summary_parts.append(f"{error_type}(1): {examples[0]}")
+            else:
+                example_str = f" e.g. {examples[0]}" if examples else ""
+                summary_parts.append(f"{error_type}({count}){example_str}")
+
+        return "; ".join(summary_parts)
+
+    def tune_for_target_size(
+        self,
+        m: int,
+        k: int,
+        n: int,
+        num_groups: int,
+        use_cache: bool = True,
+    ) -> Tuple[int, int, int]:
+        cache_key = self._get_cache_key(m, k, n, num_groups)
+
+        if use_cache:
+            cached_result = self._load_cached_result(cache_key)
+            if cached_result is not None:
+                logger.debug(f"Using cached tiling for {cache_key}: {cached_result}")
+                return cached_result
+
+        logger.debug(
+            f"Tuning tiling for problem size: m={m}, k={k}, n={n}, groups={num_groups}"
+        )
+
+        lhs, rhs, group_sizes = self._create_test_data(m, k, n, num_groups)
+
+        candidates = self._generate_tiling_candidates(m, k, n)
+
+        best_tiling = None
+        best_time = float("inf")
+        failed_count = 0
+        failure_reasons = {}  # Track failure reasons
+
+        for i, tiling in enumerate(candidates):
+            try:
+                avg_time = self._benchmark_tiling(lhs, rhs, group_sizes, tiling)
+                if avg_time < best_time:
+                    best_time = avg_time
+                    best_tiling = tiling
+
+            except Exception as e:
+                failed_count += 1
+                error_type = type(e).__name__
+                error_msg = str(e)
+                if error_type not in failure_reasons:
+                    failure_reasons[error_type] = {"count": 0, "examples": []}
+                failure_reasons[error_type]["count"] += 1
+                if len(failure_reasons[error_type]["examples"]) < 3:
+                    failure_reasons[error_type]["examples"].append(
+                        f"{tiling}: {error_msg}"
+                    )
+                logger.debug(f"Tiling {tiling} failed: {error_type}: {error_msg}")
+                continue
+
+        if best_tiling is None:
+            fallback_tm = 8  # Start with small value for decode compatibility
+            fallback_tk = 128  # Start with TPU-safe minimum
+            fallback_tn = 128  # Start with TPU-safe minimum
+
+            tile_sizes_m = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]
+            tile_sizes_k = [128, 256, 512, 1024, 2048]
+            tile_sizes_n = [128, 256, 512, 1024, 2048]
+
+            for tm in tile_sizes_m:
+                if tm <= m and m % tm == 0:
+                    fallback_tm = tm
+
+            for tk in reversed(tile_sizes_k):
+                if tk <= k and k % tk == 0:
+                    fallback_tk = tk
+                    break
+
+            for tn in reversed(tile_sizes_n):
+                if tn <= n and n % tn == 0:
+                    fallback_tn = tn
+                    break
+
+            best_tiling = (fallback_tm, fallback_tk, fallback_tn)
+            failure_summary = self._format_failure_summary(failure_reasons)
+            logger.warning(
+                f"[GMM AUTO-TUNE] All {len(candidates)} tiling candidates failed for problem (m={m}, k={k}, n={n}, groups={num_groups}), using default {best_tiling}. "
+                f"Failure reasons: {failure_summary}"
+            )
+        else:
+            if failed_count > 0:
+                failure_summary = self._format_failure_summary(failure_reasons)
+                logger.warning(
+                    f"[GMM AUTO-TUNE] {failed_count}/{len(candidates)} tiling candidates failed for problem (m={m}, k={k}, n={n}, groups={num_groups}). "
+                    f"Failure reasons: {failure_summary}"
+                )
+
+        if use_cache:
+            self._save_cached_result(cache_key, best_tiling, best_time)
+
+        return best_tiling
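Usage sketch of the tuner above (the problem sizes are illustrative; only the API defined in this new file is assumed):

    from sgl_jax.srt.layers.gmm.auto_tune_tiling import TilingAutoTuner

    tuner = TilingAutoTuner(cache_dir="/tmp/tune_cache")
    # Benchmarks every valid (tm, tk, tn) candidate for a grouped matmul of
    # lhs (m, k) x rhs (num_groups, k, n) and returns the fastest tiling.
    best = tuner.tune_for_target_size(m=1024, k=2048, n=1408, num_groups=64)
    # The winner is also persisted as /tmp/tune_cache/m1024_k2048_n1408_g64.json:
    # {"optimal_tiling": [tm, tk, tn], "best_time_ms": ..., "timestamp": ...}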
diff --git a/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/gmm.py b/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/gmm.py
index 42bcf714..c04acfcb 100644
--- a/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/gmm.py
+++ b/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/gmm.py
@@ -512,7 +512,9 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
         ),
         input_output_aliases=input_output_aliases,
         compiler_params=pltpu.CompilerParams(
-            dimension_semantics=("parallel", "arbitrary", "arbitrary")
+            dimension_semantics=("parallel", "arbitrary", "arbitrary"),
+            vmem_limit_bytes=int(64 * (1 << 20)),
+            disable_bounds_checks=True,
         ),
         interpret=interpret,
         cost_estimate=cost_estimate,
diff --git a/python/sgl_jax/srt/layers/gmm/tiling_manager.py b/python/sgl_jax/srt/layers/gmm/tiling_manager.py
new file mode 100644
index 00000000..13f0f9bc
--- /dev/null
+++ b/python/sgl_jax/srt/layers/gmm/tiling_manager.py
@@ -0,0 +1,152 @@
+import json
+import os
+import threading
+from typing import Dict, List, Optional, Tuple
+
+
+def get_default_cache_dir() -> str:
+    return os.environ.get("GMM_TUNE_CACHE_DIR", "/tmp/tune_cache")
+
+
+class TilingManager:
+    _instance = None
+    _lock = threading.Lock()
+
+    def __new__(cls, cache_dir: Optional[str] = None):
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = super().__new__(cls)
+                    cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self, cache_dir: Optional[str] = None):
+        if self._initialized:
+            return
+
+        self.cache_dir = cache_dir or get_default_cache_dir()
+        self.tiling_cache: Dict[str, Tuple[int, int, int]] = {}
+        self.default_tiling = (8, 1024, 1024)
+        self._load_all_cached_tilings()
+        self._initialized = True
+
+    def _get_cache_key(self, m: int, k: int, n: int, num_groups: int) -> str:
+        return f"m{m}_k{k}_n{n}_g{num_groups}"
+
+    def _load_all_cached_tilings(self):
+        if not os.path.exists(self.cache_dir):
+            return
+
+        for filename in os.listdir(self.cache_dir):
+            if filename.endswith(".json"):
+                cache_key = filename[:-5]  # Remove .json extension
+                cache_file = os.path.join(self.cache_dir, filename)
+
+                try:
+                    with open(cache_file, "r") as f:
+                        data = json.load(f)
+                        if "optimal_tiling" in data:
+                            self.tiling_cache[cache_key] = tuple(data["optimal_tiling"])
+                except Exception:
+                    continue
+
+    def get_optimal_tiling(
+        self, m: int, k: int, n: int, num_groups: int
+    ) -> Tuple[int, int, int]:
+        cache_key = self._get_cache_key(m, k, n, num_groups)
+
+        if cache_key in self.tiling_cache:
+            return self.tiling_cache[cache_key]
+
+        # Try to find a close match with same k, n, num_groups but different m
+        # This is common when batch size varies but model dimensions stay the same
+        for cached_key, tiling in self.tiling_cache.items():
+            parts = cached_key.split("_")
+            if len(parts) == 4:
+                try:
+                    cached_m = int(parts[0][1:])  # Remove 'm' prefix
+                    cached_k = int(parts[1][1:])  # Remove 'k' prefix
+                    cached_n = int(parts[2][1:])  # Remove 'n' prefix
+                    cached_groups = int(parts[3][1:])  # Remove 'g' prefix
+
+                    if (
+                        cached_k == k
+                        and cached_n == n
+                        and cached_groups == num_groups
+                        and (
+                            abs(cached_m - m) / max(cached_m, m) < 0.5
+                            or min(cached_m, m) <= 256
+                        )
+                    ):
+                        return tiling
+                except ValueError:
+                    continue
+
+        return self.default_tiling
+
+
+_global_tiling_manager = None
+
+
+def get_tiling_manager(cache_dir: Optional[str] = None) -> TilingManager:
+    global _global_tiling_manager
+    if _global_tiling_manager is None:
+        _global_tiling_manager = TilingManager(cache_dir)
+    return _global_tiling_manager
+
+
+def get_optimal_tiling_for_gmm(
+    m: int, k: int, n: int, num_groups: int = 1
+) -> Tuple[int, int, int]:
+    manager = get_tiling_manager()
+    return manager.get_optimal_tiling(m, k, n, num_groups)
+
+
+def load_all_gmm_tiling_configs() -> Dict[str, List[int]]:
+    """Load all auto-tune GMM tiling configurations into memory after auto-tune completes."""
+    import json
+    import os
+
+    configs = {}
+    cache_dir = get_default_cache_dir()
+
+    if not os.path.exists(cache_dir):
+        print(f"[TilingManager] No auto-tune cache directory found at {cache_dir}")
+        return configs
+
+    loaded_count = 0
+
+    # Load all auto-tune results from cache files
+    for filename in os.listdir(cache_dir):
+        if not filename.endswith(".json"):
+            continue
+
+        cache_file = os.path.join(cache_dir, filename)
+        try:
+            with open(cache_file, "r") as f:
+                data = json.load(f)
+
+            if "optimal_tiling" not in data:
+                continue
+
+            # Parse cache key: m{m}_k{k}_n{n}_g{num_groups}
+            cache_key = filename[:-5]  # Remove .json extension
+            parts = cache_key.split("_")
+            if len(parts) != 4:
+                continue
+
+            try:
+                # Store in memory cache using string key format
+                configs[cache_key] = list(data["optimal_tiling"])
+                loaded_count += 1
+
+            except ValueError:
+                continue
+
+        except Exception:
+            continue
+
+    print(
+        f"[TilingManager] Loaded {loaded_count} GMM tiling configurations into memory"
+    )
+    return configs
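Lookup sketch against the manager above (sizes are illustrative; only functions defined in this file are used):

    from sgl_jax.srt.layers.gmm.tiling_manager import get_optimal_tiling_for_gmm

    # Exact key "m4096_k2048_n1408_g64" is preferred; otherwise the manager falls
    # back to a cached entry with the same k/n/num_groups and a nearby m, and
    # finally to the built-in default (8, 1024, 1024).
    tm, tk, tn = get_optimal_tiling_for_gmm(m=4096, k=2048, n=1408, num_groups=64)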
diff --git a/python/sgl_jax/srt/layers/moe.py b/python/sgl_jax/srt/layers/moe.py
index 54463182..73ad1841 100644
--- a/python/sgl_jax/srt/layers/moe.py
+++ b/python/sgl_jax/srt/layers/moe.py
@@ -173,7 +173,7 @@ def _detect_device_capabilities(self):
         except Exception as e:
             return False, "cpu"
 
-    def __call__(self, inputs, router_logits=None):
+    def __call__(self, inputs, router_logits=None, gmm_tiling_config_array=None):
         if router_logits is None:
             raise ValueError("router_logits is required for EPMoE")
 
@@ -186,15 +186,26 @@ def __call__(self, inputs, router_logits=None):
         )
 
         if self.expert_parallel_size == 1:
-            output = self._single_device_forward(inputs, router_logits)
+            output = self._single_device_forward(
+                inputs, router_logits, gmm_tiling_config_array
+            )
         else:
-            output = self._expert_parallel_forward_with_shard_map(inputs, router_logits)
+            output = self._expert_parallel_forward_with_shard_map(
+                inputs, router_logits, gmm_tiling_config_array
+            )
 
         return output
 
-    def _expert_parallel_forward_with_shard_map(self, inputs, router_logits):
+    def _expert_parallel_forward_with_shard_map(
+        self, inputs, router_logits, gmm_tiling_config_array
+    ):
         def _internal_moe_computation(
-            hidden_states, router_logits, w0_weights, w1_weights, wo_weights
+            hidden_states,
+            router_logits,
+            gmm_tiling_config_array,
+            w0_weights,
+            w1_weights,
+            wo_weights,
         ):
             data_index = jax.lax.axis_index("data")
             tensor_index = jax.lax.axis_index("tensor")
@@ -243,6 +254,7 @@ def _internal_moe_computation(
                 w0_weights,
                 w1_weights,
                 wo_weights,
+                gmm_tiling_config_array,
             )
 
             # EP Combine
@@ -268,37 +280,51 @@ def _internal_moe_computation(
             in_specs=(
                 P(None),  # hidden_states
                 P(None),  # router_logits
+                P(None),  # gmm_tiling_config_array
                 P(("data", "tensor"), None, None),  # w0_weights
                 P(("data", "tensor"), None, None),  # w1_weights
                 P(("data", "tensor"), None, None),  # wo_weights
            ),
            out_specs=P(None),
            check_rep=False,
-        )(inputs, router_logits, self.wi_0.value, self.wi_1.value, self.wo.value)
+        )(
+            inputs,
+            router_logits,
+            gmm_tiling_config_array,
+            self.wi_0.value,
+            self.wi_1.value,
+            self.wo.value,
+        )
 
     def _gmm_compute_with_sharded_weights(
-        self, x, local_group_sizes, selected_experts, w0_kernel, w1_kernel, wo_kernel
+        self,
+        x,
+        local_group_sizes,
+        selected_experts,
+        w0_kernel,
+        w1_kernel,
+        wo_kernel,
+        gmm_tiling_config_array,
     ):
         if x.shape[0] == 0:
             empty_output = jnp.zeros(
                 (0, wo_kernel.shape[-1]), dtype=x.dtype
             )  # (0, hidden_dim)
             return empty_output
-
-        m, k = x.shape[0], x.shape[1]
-        n_gate = w0_kernel.shape[2]
-        n_down = wo_kernel.shape[2]
-
-        default_tile_size = (512, 1024, 1024)
-        tiling_gate = (
-            min(default_tile_size[0], m),
-            min(default_tile_size[1], k),
-            min(default_tile_size[2], n_gate),
-        )
-        tiling_down = (
-            min(default_tile_size[0], m),
-            min(default_tile_size[1], n_gate),
-            min(default_tile_size[2], n_down),
+        jax.debug.print("x_shape: {x_shape}", x_shape=x.shape)
+        # static_tiling_gate = (
+        #     int(optimal_tiling_gate[0]),
+        #     int(optimal_tiling_gate[1]),
+        #     int(optimal_tiling_gate[2]),
+        # )
+        # static_tiling_down = (
+        #     int(optimal_tiling_down[0]),
+        #     int(optimal_tiling_down[1]),
+        #     int(optimal_tiling_down[2]),
+        # )
+        jax.debug.print(
+            "gmm_tiling_array: {gmm_tiling_config_array}",
+            gmm_tiling_config_array=gmm_tiling_config_array,
        )
 
        # gate
        layer_w0 = gmm(
@@ -306,7 +332,7 @@ def _gmm_compute_with_sharded_weights(
            rhs=w0_kernel,
            group_sizes=local_group_sizes,
            preferred_element_type=self.dtype,
-            tiling=tiling_gate,
+            tiling=gmm_tiling_config_array[0],
        )
        # up
        layer_w1 = gmm(
@@ -314,7 +340,7 @@ def _gmm_compute_with_sharded_weights(
            rhs=w1_kernel,
            group_sizes=local_group_sizes,
            preferred_element_type=self.dtype,
-            tiling=tiling_gate,
+            tiling=gmm_tiling_config_array[0],
        )
 
        # activation
@@ -327,12 +353,12 @@ def _gmm_compute_with_sharded_weights(
            rhs=wo_kernel,
            group_sizes=local_group_sizes,
            preferred_element_type=self.dtype,
-            tiling=tiling_down,
+            tiling=gmm_tiling_config_array[0],
        )
 
        return intermediate_output
 
-    def _single_device_forward(self, inputs, router_logits):
+    def _single_device_forward(self, inputs, router_logits, gmm_tiling_config_array):
        top_k_logits, top_k_indices = jax.lax.top_k(
            router_logits, self.num_experts_per_tok
        )
@@ -342,9 +368,13 @@ def _single_device_forward(self, inputs, router_logits):
 
        top_k_weights = top_k_weights / jnp.sum(top_k_weights, axis=-1, keepdims=True)
 
-        return self._single_device_forward(inputs, top_k_indices, top_k_weights)
 
-    def _single_device_forward(self, inputs, top_k_indices, top_k_weights):
+        return self._single_device_forward_impl(
+            inputs, top_k_indices, top_k_weights, gmm_tiling_config_array
+        )
+
+    def _single_device_forward_impl(
+        self, inputs, top_k_indices, top_k_weights, gmm_tiling_config_array
+    ):
        num_tokens = inputs.shape[0] * (inputs.shape[1] if inputs.ndim > 1 else 1)
        inputs_flat = inputs.reshape(num_tokens, -1)
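For reference, the tiling payload consumed above is the (1, 3) int32 array built per batch in tp_worker.py further below; a minimal sketch, with an illustrative fallback value:

    import numpy as np

    # Row 0 holds (tm, tk, tn); in this revision the gate, up, and down grouped
    # matmuls all read gmm_tiling_config_array[0].
    gmm_tiling_config_array = np.zeros((1, 3), dtype=np.int32)
    gmm_tiling_config_array[0] = [512, 1024, 1024]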
diff --git a/python/sgl_jax/srt/managers/schedule_batch.py b/python/sgl_jax/srt/managers/schedule_batch.py
index eeeb05c2..2c9c9426 100644
--- a/python/sgl_jax/srt/managers/schedule_batch.py
+++ b/python/sgl_jax/srt/managers/schedule_batch.py
@@ -20,7 +20,7 @@
 import logging
 import threading
 from http import HTTPStatus
-from typing import Any, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 import numpy as np
 from jax import numpy as jnp
@@ -1025,6 +1025,7 @@ def get_model_worker_batch(
         bs_paddings: list,
         cache_loc_paddings: List,
         page_size: int,
+        gmm_tiling_configs: Optional[Dict[str, List[int]]] = None,
     ) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
@@ -1244,6 +1245,7 @@ def get_model_worker_batch(
             real_bs=real_bs,
             capture_hidden_mode=CaptureHiddenMode.NULL,
             launch_done=self.launch_done,
+            gmm_tiling_configs=gmm_tiling_configs,
         )
 
     def _generate_trace_info(self, real_bs: int, bid: int) -> List[str]:
@@ -1350,6 +1352,8 @@ class ModelWorkerBatch:
 
     capture_hidden_mode: CaptureHiddenMode = None
 
+    gmm_tiling_config_array: np.ndarray = None
+
     # For logits and logprobs post processing
     temp_scaled_logprobs: bool = False
     temperature: np.ndarray = None
diff --git a/python/sgl_jax/srt/managers/scheduler.py b/python/sgl_jax/srt/managers/scheduler.py
index d8fd9d2d..3830fa3e 100644
--- a/python/sgl_jax/srt/managers/scheduler.py
+++ b/python/sgl_jax/srt/managers/scheduler.py
@@ -306,6 +306,11 @@ def __init__(
             ]
         )
 
+        if not server_args.disable_gmm_auto_tune:
+            logger.info(f"[Scheduler] Begins to run GMM auto-tuning.")
+            self.tp_worker.run_gmm_auto_tune()
+            logger.info(f"[Scheduler] Completes GMM auto-tuning.")
+
         if not server_args.disable_jax_precompile:
             logger.info(f"[Scheduler] Begins to run worker precompile.")
             self.tp_worker.run_precompile()
@@ -908,6 +913,7 @@ def run_batch(self, batch: ScheduleBatch) -> Union[GenerationBatchResult]:
             precompile_bs_paddings,
             precompile_cache_loc_paddings,
             self.page_size,
+            self.tp_worker.get_gmm_tiling_configs(),
         )
 
         sampling_metadata = SamplingMetadata.from_model_worker_batch(
diff --git a/python/sgl_jax/srt/managers/tp_worker.py b/python/sgl_jax/srt/managers/tp_worker.py
index 6622b41d..48baf306 100644
--- a/python/sgl_jax/srt/managers/tp_worker.py
+++ b/python/sgl_jax/srt/managers/tp_worker.py
@@ -2,6 +2,10 @@
 import itertools
 import logging
+import os
+
+# Add import for GMM auto tune
+import sys
 import threading
 import time
 from typing import Optional, Tuple, Union
@@ -14,6 +18,11 @@
 from tqdm import tqdm
 
 from sgl_jax.srt.configs.model_config import ModelConfig
+from sgl_jax.srt.layers.gmm.auto_tune_tiling import TilingAutoTuner
+from sgl_jax.srt.layers.gmm.tiling_manager import (
+    get_default_cache_dir,
+    load_all_gmm_tiling_configs,
+)
 from sgl_jax.srt.layers.logits_processor import LogitsMetadata, LogitsProcessorOutput
 from sgl_jax.srt.managers.schedule_batch import (
     ModelWorkerBatch,
@@ -82,6 +91,9 @@ def __init__(
             rngs=nnx.Rngs(self.random_seed),
         )
 
+        # Initialize empty GMM tiling configs (will be populated after auto-tune)
+        self.gmm_tiling_configs = {}
+
         # set infer devices
         self.device = server_args.device
 
@@ -184,6 +196,162 @@ def normalize_token_paddings(self):
 
         self.precompile_token_paddings = normalized_token_paddings
 
+    def run_gmm_auto_tune(self):
+        start_time = time.perf_counter()
+        logger.info("[GMM AUTO-TUNE] Starting GMM tiling parameter auto-tuning")
+
+        try:
+            hf_config = self.model_config.hf_config
+
+            # check if this is a MoE model
+            self.num_experts = getattr(hf_config, "num_experts", None)
+            if self.num_experts is None:
+                self.disable_gmm_auto_tune = True
+                return
+            self.disable_gmm_auto_tune = False
+            logger.info(
+                f"[GMM AUTO-TUNE] MoE model detected with {self.num_experts} experts"
+            )
+
+            self.hidden_size = getattr(hf_config, "hidden_size", 4096)
+            self.intermediate_size = getattr(hf_config, "intermediate_size", None)
+            self.moe_intermediate_size = getattr(
+                hf_config, "moe_intermediate_size", None
+            )
+
+            target_intermediate_size = (
+                self.moe_intermediate_size
+                or self.intermediate_size
+                or self.hidden_size * 4
+            )
+
+            logger.info(
+                f"[GMM AUTO-TUNE] Model config: hidden_size={self.hidden_size}, "
+                f"intermediate_size={target_intermediate_size}, num_experts={self.num_experts}"
+            )
+
+            self.num_experts_per_tok = getattr(hf_config, "num_experts_per_tok", 8)
+            # Generate all combinations: m = batch_size * seq_length * num_experts_per_tok
+            logger.info(
+                f"[GMM AUTO-TUNE] Using precompile parameters for shape generation"
+            )
+            logger.info(
+                f"[GMM AUTO-TUNE] Batch size paddings: {self.precompile_bs_paddings}"
+            )
+            logger.info(
+                f"[GMM AUTO-TUNE] Sequence length paddings: {self.precompile_token_paddings}"
+            )
+            logger.info(
+                f"[GMM AUTO-TUNE] Experts per token: {self.num_experts_per_tok}"
+            )
+            logger.info(
+                f"[GMM AUTO-TUNE] Max padded num tokens: {self.max_padded_num_tokens}"
+            )
+
+            shapes = []
+            skipped_count = 0
+
+            logger.info(f"[GMM AUTO-TUNE] Adding decode-specific shapes (seq_length=1)")
+            for batch_size in self.precompile_bs_paddings:
+                seq_length = 1
+                actual_tokens = batch_size * seq_length  # = batch_size
+
+                # Decode: m = batch_size * num_experts_per_tok
+                m = actual_tokens * self.num_experts_per_tok
+
+                if m < 8 or self.hidden_size < 64 or target_intermediate_size < 64:
+                    continue
+
+                # Add shapes for decode gate/up projections (hidden -> intermediate)
+                shapes.append(
+                    (m, self.hidden_size, target_intermediate_size, self.num_experts)
+                )
+                # Add shapes for decode down projections (intermediate -> hidden)
+                shapes.append(
+                    (m, target_intermediate_size, self.hidden_size, self.num_experts)
+                )
+
+            logger.info(f"[GMM AUTO-TUNE] Adding prefill shapes (variable seq_length)")
+            total_combinations = len(self.precompile_bs_paddings) + (
+                len(self.precompile_bs_paddings) * len(self.precompile_token_paddings)
+            )
+            for batch_size in self.precompile_bs_paddings:
+                for seq_length in self.precompile_token_paddings:
+                    actual_tokens = batch_size * seq_length
+
+                    m = actual_tokens * self.num_experts_per_tok
+
+                    if m < 64 or self.hidden_size < 64 or target_intermediate_size < 64:
+                        skipped_count += 1
+                        logger.debug(
+                            f"[GMM AUTO-TUNE] Skipping tiny shape bs={batch_size}, seq={seq_length} -> m={m}, k={self.hidden_size}, n={target_intermediate_size}"
+                        )
+                        continue
+
+                    shapes.append(
+                        (
+                            m,
+                            self.hidden_size,
+                            target_intermediate_size,
+                            self.num_experts,
+                        )
+                    )
+
+                    shapes.append(
+                        (
+                            m,
+                            target_intermediate_size,
+                            self.hidden_size,
+                            self.num_experts,
+                        )
+                    )
+
+            logger.info(
+                f"[GMM AUTO-TUNE] Generated {len(shapes)} shape configurations from {total_combinations} combinations "
+                f"(skipped {skipped_count} due to size constraints)"
+            )
+            if not shapes:
+                logger.warning("[GMM AUTO-TUNE] No valid shapes to tune, skipping")
+                return
+
+            cache_dir = get_default_cache_dir()
+            tuner = TilingAutoTuner(cache_dir=cache_dir)
+            logger.info(f"[GMM AUTO-TUNE] Using cache directory: {cache_dir}")
+
+            results = {}
+            total_shapes = len(shapes)
+            logger.info(f"[GMM AUTO-TUNE] Tuning {total_shapes} shape configurations")
+
+            with tqdm(shapes, desc="[GMM AUTO-TUNE] Progress", leave=False) as pbar:
+                for i, (m, k, n, num_groups) in enumerate(pbar):
+                    pbar.set_postfix(shape=f"m={m},k={k},n={n},g={num_groups}")
+
+                    optimal_tiling = tuner.tune_for_target_size(
+                        m, k, n, num_groups, use_cache=True
+                    )
+                    cache_key = tuner._get_cache_key(m, k, n, num_groups)
+                    results[cache_key] = optimal_tiling
+
+            end_time = time.perf_counter()
+            logger.info(
+                f"[GMM AUTO-TUNE] Completed in {end_time - start_time:.1f} seconds"
+            )
+            logger.info(
+                f"[GMM AUTO-TUNE] Successfully tuned {len(results)}/{total_shapes} configurations"
+            )
+
+            # Load all GMM tiling configurations into memory for fast access
+            logger.info("[GMM AUTO-TUNE] Loading tiling configurations into memory")
+            self.gmm_tiling_configs = load_all_gmm_tiling_configs()
+
+        except Exception as e:
+            logger.error(f"[GMM AUTO-TUNE] Auto-tuning failed: {e}")
+            logger.info(
+                "[GMM AUTO-TUNE] Continuing without auto-tuning, will use default tiling parameters"
+            )
+            # Initialize empty configs for fallback
+            self.gmm_tiling_configs = {}
+
     def run_precompile(self):
         self.precompile_extend()
         self.precompile_decode()
@@ -272,6 +440,9 @@ def get_precompile_paddings(self):
             self.precompile_cache_loc_paddings,
         )
 
+    def get_gmm_tiling_configs(self):
+        return self.gmm_tiling_configs
+
     def generate_model_worker_batch(
         self,
         bs: int,
@@ -291,6 +462,25 @@ def generate_model_worker_batch(
         valid_cache_loc = np.arange(bs)
         invalid_cache_loc = np.array([0] * (invalid_cache_loc_size), dtype=jnp.int32)
 
+        gmm_tiling_config_array = np.zeros((1, 3), dtype=np.int32)
+
+        if not self.disable_gmm_auto_tune:
+            tiling_key = f"m{bs * num_tokens * self.num_experts_per_tok}_k{self.hidden_size}_n{self.moe_intermediate_size}_g{self.num_experts}"
+            # gmm_tiling_config_array = self.gmm_tiling_configs.get(
+            #     tiling_key,
+            #     None,
+            # )
+            gmm_tiling_config_array[0] = self.gmm_tiling_configs.get(
+                tiling_key,
+                [512, 1024, 1024],
+            )
+            jax.debug.print("tiling_key: {tiling_key}", tiling_key=tiling_key)
+            jax.debug.print(
+                "gmm_tiling_config_array: {gmm_tiling_config_array}",
+                gmm_tiling_config_array=gmm_tiling_config_array,
+            )
+        else:
+            gmm_tiling_config_array = None
 
         return ModelWorkerBatch(
             bid=1,
@@ -319,6 +509,7 @@ def generate_model_worker_batch(
             token_ids_logprobs=None,
             extend_logprob_start_lens=None,
             capture_hidden_mode=CaptureHiddenMode.NULL,
+            gmm_tiling_config_array=gmm_tiling_config_array,
         )
 
     def get_model_runner(self):
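Worked example of the lookup key built in generate_model_worker_batch above (all numbers are illustrative, not taken from a specific model): for a padded decode batch with bs=8, num_tokens=1, num_experts_per_tok=8, hidden_size=4096, moe_intermediate_size=1408, num_experts=64:

    m = 8 * 1 * 8  # bs * num_tokens * num_experts_per_tok = 64
    tiling_key = f"m{m}_k{4096}_n{1408}_g{64}"  # -> "m64_k4096_n1408_g64"
    # If auto-tune wrote m64_k4096_n1408_g64.json into the cache directory, that
    # tiling fills row 0 of gmm_tiling_config_array; otherwise [512, 1024, 1024]
    # is used as the fallback.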
diff --git a/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py b/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py
index e5dc0928..58263fbe 100644
--- a/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py
@@ -245,5 +245,11 @@ def run_precompile(self):
         )
         self.worker.run_precompile()
 
+    def run_gmm_auto_tune(self):
+        self.worker.run_gmm_auto_tune()
+
+    def get_gmm_tiling_configs(self):
+        return self.worker.get_gmm_tiling_configs()
+
     def __delete__(self):
         self.input_queue.put((None, None, None, None))
diff --git a/python/sgl_jax/srt/model_executor/forward_batch_info.py b/python/sgl_jax/srt/model_executor/forward_batch_info.py
index bff82d37..785a776f 100644
--- a/python/sgl_jax/srt/model_executor/forward_batch_info.py
+++ b/python/sgl_jax/srt/model_executor/forward_batch_info.py
@@ -20,7 +20,7 @@
 from dataclasses import dataclass
 from enum import IntEnum, auto
 from functools import total_ordering
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
 import jax
 
@@ -163,6 +163,8 @@ class ForwardBatch:
     trace_request_ids: Optional[List[str]] = None
     trace_request_objects: Optional[List] = None
 
+    gmm_tiling_config_array: jax.Array = None
+
     def tree_flatten(self):
         children = (
             self.input_ids,
@@ -176,6 +178,7 @@ def tree_flatten(self):
             self.cache_loc,
             self.extend_prefix_lens,
             self.extend_seq_lens,
+            self.gmm_tiling_config_array,
         )
 
         aux_data = {
@@ -204,6 +207,7 @@ def tree_unflatten(cls, aux_data, children):
         obj.cache_loc = children[8]
         obj.extend_prefix_lens = children[9]
         obj.extend_seq_lens = children[10]
+        obj.gmm_tiling_config_array = children[11]
 
         return obj
 
@@ -220,6 +224,7 @@ def __repr__(self) -> str:
             "cache_loc",
             "extend_prefix_lens",
             "extend_seq_lens",
+            "gmm_tiling_config_array",
         ]:
             value = getattr(self, field_name, None)
             if value is not None and isinstance(value, jax.Array):
@@ -244,6 +249,7 @@ def init_new(
             cache_loc,
             extend_prefix_lens,
             extend_seq_lens,
+            gmm_tiling_config_array,
         ) = device_array(
             model_runner.mesh,
             (
@@ -256,8 +262,10 @@ def init_new(
                 batch.cache_loc,
                 batch.extend_prefix_lens,
                 batch.extend_seq_lens,
+                batch.gmm_tiling_config_array,
             ),
         )
+
         obj = cls(
             bid=batch.bid,
             forward_mode=batch.forward_mode,
@@ -273,6 +281,7 @@ def init_new(
             extend_seq_lens=extend_seq_lens,
             token_to_kv_pool=model_runner.token_to_kv_pool,
             attn_backend=model_runner.attn_backend,
+            gmm_tiling_config_array=gmm_tiling_config_array,
         )
 
         return obj
diff --git a/python/sgl_jax/srt/models/qwen3_moe.py b/python/sgl_jax/srt/models/qwen3_moe.py
index f4aaa08f..83bd7520 100644
--- a/python/sgl_jax/srt/models/qwen3_moe.py
+++ b/python/sgl_jax/srt/models/qwen3_moe.py
@@ -252,10 +252,13 @@ def __call__(
         hidden_states += residual
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-
         if self.is_moe_layer:
             router_logits = self.moe_gate(hidden_states)
-            mlp_output = self.mlp(hidden_states, router_logits=router_logits)
+            mlp_output = self.mlp(
+                hidden_states,
+                router_logits=router_logits,
+                gmm_tiling_config_array=forward_batch.gmm_tiling_config_array,
+            )
             hidden_states = mlp_output
         else:
             hidden_states = self.mlp(hidden_states)
diff --git a/python/sgl_jax/srt/server_args.py b/python/sgl_jax/srt/server_args.py
index 478cc8af..2eebc573 100644
--- a/python/sgl_jax/srt/server_args.py
+++ b/python/sgl_jax/srt/server_args.py
@@ -128,6 +128,7 @@ class ServerArgs:
     precompile_bs_paddings: Optional[List[int]] = None
 
     disable_jax_precompile: bool = False
+    disable_gmm_auto_tune: bool = False
 
     def __post_init__(self):
         # Set missing default values
@@ -738,6 +739,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
            action="store_true",
            help="whether disable jax precompile",
        )
+        parser.add_argument(
+            "--disable-gmm-auto-tune",
+            action="store_true",
+            help="Disable automatic tuning of GMM (Grouped Matrix Multiplication) tiling parameters at startup",
+        )
        # Kernel backend
        parser.add_argument(
            "--attention-backend",
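The two new knobs introduced above can be combined; a minimal sketch (the model path is a placeholder, and it assumes no other ServerArgs fields are required):

    import os
    from sgl_jax.srt.server_args import ServerArgs

    # Persist tuned tilings somewhere durable instead of /tmp/tune_cache.
    os.environ["GMM_TUNE_CACHE_DIR"] = "/var/cache/sgl_jax/gmm_tune"

    # Opt out of startup auto-tuning; per the tp_worker change above,
    # gmm_tiling_config_array is then left as None for generated batches.
    args = ServerArgs(model_path="/path/to/model", disable_gmm_auto_tune=True)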