diff --git a/python/sgl_jax/srt/layers/attention/flashattention_backend.py b/python/sgl_jax/srt/layers/attention/flashattention_backend.py
index 0865d07c..4041b5b2 100644
--- a/python/sgl_jax/srt/layers/attention/flashattention_backend.py
+++ b/python/sgl_jax/srt/layers/attention/flashattention_backend.py
@@ -101,14 +101,14 @@ def get_forward_metadata(self, batch: ModelWorkerBatch):
             cu_q_lens = np.concatenate(
                 [
                     np.array([0], dtype=np.int32),
-                    np.cumsum(batch.extend_seq_lens),
+                    np.cumsum(batch.extend_seq_lens, dtype=np.int32),
                 ]
             )
         elif batch.forward_mode == ForwardMode.DECODE:
-            cu_q_lens = jnp.concatenate(
+            cu_q_lens = np.concatenate(
                 [
-                    np.array([0], dtype=jnp.int32),
-                    np.cumsum(jnp.ones(len(batch.seq_lens), dtype=np.int32)),
+                    np.array([0], dtype=np.int32),
+                    np.cumsum(np.ones(len(batch.seq_lens), dtype=np.int32)),
                 ]
             )
         else:
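A side note on the `np.cumsum` change above: without an explicit `dtype`, NumPy promotes `int32` inputs to the platform default integer, so the two branches could hand the kernel differently-typed `cu_q_lens` arrays. A minimal sketch of the pitfall (illustrative values, not project code):

```python
import numpy as np

# np.cumsum promotes int32 input to the default platform integer (int64 on
# most 64-bit systems) unless a dtype is pinned explicitly.
extend_seq_lens = np.array([3, 5, 2], dtype=np.int32)

print(np.cumsum(extend_seq_lens).dtype)                  # int64 on most platforms
print(np.cumsum(extend_seq_lens, dtype=np.int32).dtype)  # int32, as expected downstream
```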
""" logits = jnp.reshape( @@ -101,7 +99,7 @@ def __call__( _, rng = jax.random.split(self.rngs.params()) - operands = (logits, sampling_metadata, rng) + operands = (logits, sampling_metadata, positions, rng) batch_next_token_ids, logprobs = lax.cond( sampling_metadata.is_all_greedy, self._greedy_sampling, @@ -158,19 +156,75 @@ def top_k_top_p_min_p_sampling_from_probs_jax( top_ks: jax.Array, top_ps: jax.Array, min_ps: jax.Array, - need_min_p_sampling: bool, - rng: nnx.Rngs, + positions: jax.Array, + sampling_seeds: jax.Array = None, + need_min_p_sampling: bool = False, + rng: nnx.Rngs = None, ): """A top-k, top-p and min-p sampling implementation with native jax operations.""" probs_sort, probs_idx = _sample_part_a( probs, top_ks, top_ps, min_ps, need_min_p_sampling ) - sampled_index = random.categorical(rng, jnp.log(probs_sort)).reshape(-1, 1) + multinomial_operands = (probs_sort, sampling_seeds, positions, rng) + sampled_index = lax.cond( + sampling_seeds is not None, + multinomial_with_seed, + multinomial, + multinomial_operands, + ) return _sample_part_b(probs_idx, sampled_index) +def multinomial( + operands, +) -> jax.Array: + inputs, _, _, rng = operands + return random.categorical(rng, jnp.log(inputs)).reshape(-1, 1) + + +def multinomial_with_seed( + operands, +) -> jax.Array: + """ + Note: + 1. This implementation is copied from https://github.com/sgl-project/sglang/blob/e2ac7888b8cb1fd6c33a7ec58d27a5f5b5b24e0c/python/sglang/srt/layers/sampler.py#L268. + 2. Based on last response in issue, the fixed four big prime numbers can be set freely. 8589934591 is out of uin32, so I replace it with 805306457. + - issue: https://github.com/sgl-project/sglang/issues/10938 + + Samples n elements from an input array `inputs` of shape (n, m) using + a unique random seed for each row. + + Args: + inputs: A float array of shape (n, m) representing n categorical + distributions with m categories each. The values are treated + as weights and do not need to sum to 1. + seed: An integer array of shape (n,) containing the random seed + for each corresponding row in `inputs`. + positions: The positions of the tokens in the sequence. + + Returns: + A array of shape (n,) where the i-th element is an index sampled + from the distribution in `inputs[i]` using `seed[i]`. 
+ """ + inputs, seed, positions, _ = operands + if seed is None: + # note: this codes is used to keep compatible with lax.cond + return multinomial(operands) + n, m = inputs.shape + step_seed = seed * 19349663 ^ positions * 73856093 + seed_expanded = step_seed[:, None] + col_indices = jnp.arange(m)[None, :] + hashed = seed_expanded * 805306457 ^ col_indices * 479001599 + uniform_samples = (hashed % (2**24)).astype(jnp.float32) / (2**24) + epsilon = 1e-9 + gumbel_noise = -jnp.log(-jnp.log(uniform_samples + epsilon) + epsilon) + log_probs = jnp.log(inputs + epsilon) + perturbed_log_probs = log_probs + gumbel_noise + return jnp.argmax(perturbed_log_probs, axis=1, keepdims=True) + + def _apply_min_p_filter(operands): """Apply min_p filtering when need_min_p_sampling=True""" probs_sort, min_ps = operands diff --git a/python/sgl_jax/srt/managers/schedule_batch.py b/python/sgl_jax/srt/managers/schedule_batch.py index 44245d46..b9a62a47 100644 --- a/python/sgl_jax/srt/managers/schedule_batch.py +++ b/python/sgl_jax/srt/managers/schedule_batch.py @@ -47,6 +47,7 @@ GLOBAL_SERVER_ARGS_KEYS = [ "device", "disable_radix_cache", + "enable_deterministic_sampling", ] PADDING_BUCKETS = [1 << i for i in range(6, 21)] diff --git a/python/sgl_jax/srt/managers/scheduler_output_processor_mixin.py b/python/sgl_jax/srt/managers/scheduler_output_processor_mixin.py index b4989a43..20a17b4b 100644 --- a/python/sgl_jax/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sgl_jax/srt/managers/scheduler_output_processor_mixin.py @@ -142,6 +142,12 @@ def process_batch_result_prefill( logprob_pt += num_input_logprobs batch.cache_miss_count = cache_miss_count + + if batch.cache_miss_count > 0: + logger.info( + f"Prefill batch. #bid: {result.bid}, #cache_miss: {cache_miss_count}" + ) + self.stream_output( batch.reqs, batch.return_logprob, skip_stream_req, cache_miss_count ) diff --git a/python/sgl_jax/srt/managers/tp_worker.py b/python/sgl_jax/srt/managers/tp_worker.py index 27811807..c58b2bbc 100644 --- a/python/sgl_jax/srt/managers/tp_worker.py +++ b/python/sgl_jax/srt/managers/tp_worker.py @@ -11,6 +11,7 @@ import numpy as np from flax import nnx from jax.experimental.multihost_utils import broadcast_one_to_all +from jax.sharding import NamedSharding, PartitionSpec from tqdm import tqdm from sgl_jax.srt.configs.model_config import ModelConfig @@ -33,6 +34,7 @@ PRECOMPILE_DEFAULT_BS_PADDINGS, PRECOMPILE_DEFAULT_TOKEN_PADDINGS, ) +from sgl_jax.srt.utils.jax_utils import device_array logger = logging.getLogger(__name__) @@ -212,6 +214,9 @@ def precompile_extend(self, future_token_ids_map=None): ForwardMode.EXTEND, self.precompile_cache_loc_paddings[-1], ) + sampling_metadata = SamplingMetadata.from_model_worker_batch( + model_worker_batch, 0, self.mesh + ) model_worker_batch.forward_batch = ForwardBatch.init_new( model_worker_batch, self.model_runner ) @@ -222,8 +227,10 @@ def precompile_extend(self, future_token_ids_map=None): future_token_ids_map, ) ) - self.forward_batch_generation(model_worker_batch, None, True) + self.forward_batch_generation( + model_worker_batch, None, False, sampling_metadata + ) end_time = time.perf_counter() logger.info("[EXTEND] Precompile finished in %.0f secs", end_time - start_time) @@ -266,6 +273,13 @@ def precompile_decode(self, future_token_ids_map=None): end_time = time.perf_counter() logger.info("[DECODE] Precompile finished in %.0f secs", end_time - start_time) + def set_forward_metadata(self, model_worker_batch: ModelWorkerBatch): + 
+        self.model_runner.attn_backend.forward_metadata = (
+            self.model_runner.attn_backend.get_forward_metadata(
+                model_worker_batch
+            )
+        )
+
     def get_max_padded_size(self):
         """Calculate the max padded batch size and token nums.

@@ -393,6 +407,21 @@ def forward_batch_generation(
             )
             self.model_runner.attn_backend.forward_metadata = forward_metadata

+        # note: put positions on device again because forward_batch has been donated
+        if not skip_sample:
+            positions = (
+                model_worker_batch.positions
+                if model_worker_batch.forward_mode.is_decode()
+                else model_worker_batch.seq_lens - 1
+            )
+            positions_device = device_array(
+                positions,
+                sharding=(
+                    NamedSharding(self.model_runner.mesh, PartitionSpec())
+                    if jax.process_count() == 1
+                    else None
+                ),
+            )
         logits_output, cache_miss_count = self.model_runner.forward(
             forward_batch,
             logits_metadata=LogitsMetadata.from_model_worker_batch(
@@ -411,7 +440,9 @@ def forward_batch_generation(

         with jtu.count_pjit_cpp_cache_miss() as count:
             next_token_ids_device = self.model_runner.sample(
-                logits_output, sampling_metadata
+                logits_output,
+                sampling_metadata,
+                positions_device,
             )
             sample_cache_miss_count = count()
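The `device_array` re-upload above exists because the `forward_batch` buffers are donated to the jitted forward, and donated inputs must not be read again afterwards. A minimal sketch of that behavior (illustrative only; on backends without donation support, such as CPU, JAX skips donation and the original array stays alive):

```python
import jax
import jax.numpy as jnp

# Donating an input lets XLA reuse its buffer for the output, so the donated
# array is invalidated after the call and must be re-created to be used again.
step = jax.jit(lambda x: x + 1, donate_argnums=0)

x = jnp.arange(4)
y = step(x)            # x's buffer may now back y
print(x.is_deleted())  # True where donation actually happened
```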
diff --git a/python/sgl_jax/srt/model_executor/model_runner.py b/python/sgl_jax/srt/model_executor/model_runner.py
index c8ab62fa..7eb8460a 100644
--- a/python/sgl_jax/srt/model_executor/model_runner.py
+++ b/python/sgl_jax/srt/model_executor/model_runner.py
@@ -17,7 +17,11 @@
 from sgl_jax.srt.configs.model_config import AttentionArch, MockModelConfig, ModelConfig
 from sgl_jax.srt.layers.logits_processor import LogitsMetadata, LogitsProcessorOutput
 from sgl_jax.srt.layers.sampler import Sampler
-from sgl_jax.srt.managers.schedule_batch import ModelWorkerBatch
+from sgl_jax.srt.managers.schedule_batch import (
+    GLOBAL_SERVER_ARGS_KEYS,
+    ModelWorkerBatch,
+    global_server_args_dict,
+)
 from sgl_jax.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     PagedTokenToKVPoolAllocator,
@@ -83,6 +87,12 @@ def __init__(
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA

         self.forward_pass_id = 0
+
+        # Global vars
+        global_server_args_dict.update(
+            {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
+        )
+
         self.model_loader = JAXModelLoader(
             load_config=LoadConfig(
                 load_format=LoadFormat.JAX, download_dir=server_args.download_dir
@@ -431,19 +441,21 @@ def sample(
         self,
         logits_output: LogitsProcessorOutput,
         sampling_metadata: SamplingMetadata,
+        positions: jax.Array,
     ) -> jax.Array:
         """Sample and compute logprobs and update logits_output.

         Args:
             logits_output: The logits output from the model forward
             forward_batch: The forward batch that generates logits_output
-
+            positions: The positions of the tokens in the sequence.

         Returns:
             A list of next_token_ids
         """
         return self.jitted_sampler(
             logits_output,
             sampling_metadata,
+            positions,
         )
diff --git a/python/sgl_jax/srt/sampling/sampling_batch_info.py b/python/sgl_jax/srt/sampling/sampling_batch_info.py
index 7dde2154..6ecbbe3f 100644
--- a/python/sgl_jax/srt/sampling/sampling_batch_info.py
+++ b/python/sgl_jax/srt/sampling/sampling_batch_info.py
@@ -7,7 +7,8 @@
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from jax.tree_util import register_pytree_node_class

-from sgl_jax.srt.sampling.sampling_params import TOP_K_ALL
+from sgl_jax.srt.sampling.sampling_params import DEFAULT_SAMPLING_SEED, TOP_K_ALL
+from sgl_jax.srt.utils import get_bool_env_var
 from sgl_jax.srt.utils.jax_utils import device_array

 if TYPE_CHECKING:
@@ -40,6 +41,7 @@ class SamplingMetadata:
     top_ps: jax.Array
     top_ks: jax.Array
     min_ps: jax.Array
+    sampling_seeds: jax.Array
     is_all_greedy: bool = False
     need_min_p_sampling: bool = False

@@ -49,14 +51,15 @@ def tree_flatten(self):
             self.top_ps,
             self.top_ks,
             self.min_ps,
+            self.sampling_seeds,
+            self.is_all_greedy,
+            self.need_min_p_sampling,
         )

         aux_data = {
             "return_logprob": self.return_logprob,
             "top_logprobs_nums": self.top_logprobs_nums,
             "token_ids_logprobs": self.token_ids_logprobs,
-            "is_all_greedy": self.is_all_greedy,
-            "need_min_p_sampling": self.need_min_p_sampling,
         }
         return (children, aux_data)

@@ -68,12 +71,13 @@ def tree_unflatten(cls, aux_data, children):
         obj.top_ps = children[1]
         obj.top_ks = children[2]
         obj.min_ps = children[3]
+        obj.sampling_seeds = children[4]
+        obj.is_all_greedy = children[5]
+        obj.need_min_p_sampling = children[6]

         obj.return_logprob = aux_data["return_logprob"]
         obj.top_logprobs_nums = aux_data["top_logprobs_nums"]
         obj.token_ids_logprobs = aux_data["token_ids_logprobs"]
-        obj.is_all_greedy = aux_data["is_all_greedy"]
-        obj.need_min_p_sampling = aux_data["need_min_p_sampling"]

         return obj

@@ -84,6 +88,9 @@ def from_model_worker_batch(
         pad_size: int = 0,
         mesh: Mesh = None,
     ) -> SamplingMetadata:
+        sharding = (
+            NamedSharding(mesh, PartitionSpec()) if jax.process_count() == 1 else None
+        )
         padded_temperatures = np.concat(
             [
                 batch.sampling_info.temperatures,
@@ -101,7 +108,7 @@
         padded_top_ks = np.concat(
             [
                 batch.sampling_info.top_ks,
-                np.array([-1] * pad_size, dtype=batch.sampling_info.top_ks.dtype),
+                np.array([1] * pad_size, dtype=batch.sampling_info.top_ks.dtype),
             ]
         )
         padded_min_ps = np.concat(
@@ -110,15 +117,26 @@
                 np.array([0.0] * pad_size, dtype=batch.sampling_info.min_ps.dtype),
             ]
         )
+        if batch.sampling_info.sampling_seeds is not None:
+            padded_sampling_seeds = np.concat(
+                [
+                    batch.sampling_info.sampling_seeds,
+                    np.array(
+                        [DEFAULT_SAMPLING_SEED] * pad_size,
+                        dtype=batch.sampling_info.sampling_seeds.dtype,
+                    ),
+                ]
+            )
+            sampling_seeds_device = device_array(
+                padded_sampling_seeds, sharding=sharding
+            )
+        else:
+            sampling_seeds_device = None

         (temperatures_device, top_ps_device, top_ks_device, min_ps_device) = (
             device_array(
                 (padded_temperatures, padded_top_ps, padded_top_ks, padded_min_ps),
-                sharding=(
-                    NamedSharding(mesh, PartitionSpec())
-                    if jax.process_count() == 1
-                    else None
-                ),
+                sharding=sharding,
             )
         )

@@ -130,6 +148,7 @@
             top_ps=top_ps_device,
             top_ks=top_ks_device,
             min_ps=min_ps_device,
+            sampling_seeds=sampling_seeds_device,
             is_all_greedy=batch.sampling_info.is_all_greedy,
             need_min_p_sampling=batch.sampling_info.need_min_p_sampling,
         )
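On the pytree change above: moving `is_all_greedy` and `need_min_p_sampling` from `aux_data` into `children` makes them dynamic leaves rather than part of the static jit cache key, so toggling them does not force a retrace and `lax.cond` can branch on the traced value. A simplified sketch of that distinction (stand-in branches, not the real sampler):

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def pick(is_all_greedy, logits):
    # A traced boolean argument can drive lax.cond without per-value
    # recompiles; a static (aux_data-style) boolean would instead bake
    # one branch into each trace.
    return lax.cond(
        is_all_greedy,
        lambda x: jnp.argmax(x, -1),
        lambda x: jnp.argmin(x, -1),  # stand-in for the sampling branch
        logits,
    )

logits = jnp.array([[0.1, 0.9]])
print(pick(jnp.bool_(True), logits), pick(jnp.bool_(False), logits))  # one trace serves both
```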
@@ -162,12 +181,26 @@ class SamplingBatchInfo:
     # An event used for overlap schedule
     sampling_info_done: Optional[threading.Event] = None

+    sampling_seeds: Optional[np.ndarray] = None
+
+    @classmethod
+    def _get_global_server_args_dict(cls):
+        from sgl_jax.srt.managers.schedule_batch import global_server_args_dict
+
+        return global_server_args_dict
+
     @classmethod
     def generate_for_precompile(cls, bs: int):
-        temperatures = np.array([1.0 for _ in range(bs)], dtype=np.float32)
-        top_ps = np.array([1.0 for _ in range(bs)], dtype=np.float32)
-        top_ks = np.array([-1 for _ in range(bs)], dtype=np.int32)
-        min_ps = np.array([0.0 for _ in range(bs)], dtype=np.float32)
+        temperatures = np.array([0.6 for _ in range(bs)], dtype=np.float32)
+        top_ps = np.array([0.9 for _ in range(bs)], dtype=np.float32)
+        top_ks = np.array([30 for _ in range(bs)], dtype=np.int32)
+        min_ps = np.array([0.6 for _ in range(bs)], dtype=np.float32)
+        if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_SAMPLING"):
+            sampling_seeds = np.array(
+                [DEFAULT_SAMPLING_SEED for _ in range(bs)], dtype=np.int32
+            )
+        else:
+            sampling_seeds = None

         ret = cls(
             temperatures=temperatures,
@@ -176,14 +209,17 @@
             min_ps=min_ps,
             is_all_greedy=True,
             need_top_p_sampling=False,
-            need_top_k_sampling=True,
-            need_min_p_sampling=False,
+            need_top_k_sampling=False,
+            need_min_p_sampling=True,
             sampling_info_done=None,
+            sampling_seeds=sampling_seeds,
         )

         return ret

     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+        global_server_args_dict = cls._get_global_server_args_dict()
+        enable_deterministic = global_server_args_dict["enable_deterministic_sampling"]
         reqs = batch.reqs
         temperatures = np.array(
             [r.sampling_params.temperature for r in reqs],
@@ -193,6 +229,15 @@
         top_ks = np.array([r.sampling_params.top_k for r in reqs], dtype=np.int32)
         min_ps = np.array([r.sampling_params.min_p for r in reqs], dtype=np.float32)

+        sampling_seeds = (
+            np.array(
+                [r.sampling_params.sampling_seed for r in reqs],
+                dtype=np.int32,
+            )
+            if enable_deterministic
+            else None
+        )
+
         ret = cls(
             temperatures=temperatures,
             top_ps=top_ps,
@@ -202,6 +247,7 @@
             need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
             need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
+            sampling_seeds=sampling_seeds,
         )

         return ret

@@ -217,9 +263,11 @@ def filter_batch(self, keep_indices: np.ndarray):
             "top_ps",
             "top_ks",
             "min_ps",
+            "sampling_seeds",
         ]:
             value = getattr(self, item, None)
-            setattr(self, item, value[keep_indices])
+            if value is not None:
+                setattr(self, item, value[keep_indices])

     def merge_batch(self, other: "SamplingBatchInfo", mesh: Mesh):
         # Note: because the __len()__ operator is defined on the temperatures tensor,
@@ -230,10 +278,12 @@ def merge_batch(self, other: "SamplingBatchInfo", mesh: Mesh):
             "top_ps",
             "top_ks",
             "min_ps",
+            "sampling_seeds",
         ]:
             self_val = getattr(self, item, None)
             other_val = getattr(other, item, None)
-            setattr(self, item, np.concat([self_val, other_val]))
+            if self_val is not None and other_val is not None:
+                setattr(self, item, np.concat([self_val, other_val]))

         self.is_all_greedy &= other.is_all_greedy
         self.need_top_p_sampling |= other.need_top_p_sampling
diff --git a/python/sgl_jax/srt/sampling/sampling_params.py b/python/sgl_jax/srt/sampling/sampling_params.py
index 5ace12b7..c2e5754c 100644
--- a/python/sgl_jax/srt/sampling/sampling_params.py
+++ b/python/sgl_jax/srt/sampling/sampling_params.py
@@ -2,8 +2,11 @@

 from typing import Dict, List, Optional, Union

+from sgl_jax.srt.utils import get_bool_env_var
+
 _SAMPLING_EPS = 1e-6
 TOP_K_ALL = 1 << 30
+DEFAULT_SAMPLING_SEED = 42


 class SamplingParams:
@@ -39,6 +42,7 @@ def __init__(
         no_stop_trim: bool = False,
         stream_interval: Optional[int] = None,
         logit_bias: Optional[Dict[str, float]] = None,
+        sampling_seed: Optional[int] = None,
     ) -> None:
         self.max_new_tokens = max_new_tokens
         self.stop_strs = stop
@@ -65,6 +69,14 @@ def __init__(
         self.no_stop_trim = no_stop_trim
         self.stream_interval = stream_interval
         self.logit_bias = logit_bias
+        # Used for deterministic sampling
+        if (
+            get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_SAMPLING")
+            and sampling_seed is None
+        ):
+            # If deterministic sampling is enabled and sampling_seed is not set, use the default seed
+            sampling_seed = DEFAULT_SAMPLING_SEED
+        self.sampling_seed = sampling_seed

         # Process some special cases
         if 0 <= self.temperature < _SAMPLING_EPS:
diff --git a/python/sgl_jax/srt/server_args.py b/python/sgl_jax/srt/server_args.py
index 478cc8af..8c666b16 100644
--- a/python/sgl_jax/srt/server_args.py
+++ b/python/sgl_jax/srt/server_args.py
@@ -129,6 +129,9 @@ class ServerArgs:

     disable_jax_precompile: bool = False

+    # For deterministic sampling
+    enable_deterministic_sampling: bool = False
+
     def __post_init__(self):
         # Set missing default values
         if self.tokenizer_path is None:
@@ -182,6 +185,10 @@ def __post_init__(self):
             )
             self.chunked_prefill_size = -1

+        os.environ["SGLANG_ENABLE_DETERMINISTIC_SAMPLING"] = (
+            "1" if self.enable_deterministic_sampling else "0"
+        )
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and tokenizer
@@ -750,6 +757,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
         help="Choose the kernels for attention layers.",
     )

+    # For deterministic sampling
+    parser.add_argument(
+        "--enable-deterministic-sampling",
+        action="store_true",
+        help="Enable deterministic sampling",
+    )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
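Putting the plumbing together: the CLI flag sets the env var in `__post_init__`, and `SamplingParams` then defaults `sampling_seed` when the feature is on. A hedged end-to-end sketch against the HTTP endpoint (host, port, and prompt are hypothetical; assumes the server was launched with `--enable-deterministic-sampling`):

```python
import requests

payload = {
    "text": "the capital of France is",
    "sampling_params": {
        "temperature": 0.7,
        "max_new_tokens": 16,
        "sampling_seed": 42,  # if omitted, DEFAULT_SAMPLING_SEED (42) is used
    },
}
a = requests.post("http://localhost:30000/generate", json=payload).json()
b = requests.post("http://localhost:30000/generate", json=payload).json()
assert a["text"] == b["text"]  # identical seed -> identical completion
```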
diff --git a/python/sgl_jax/test/test_sampler.py b/python/sgl_jax/test/test_sampler.py
new file mode 100644
index 00000000..57ba42ef
--- /dev/null
+++ b/python/sgl_jax/test/test_sampler.py
@@ -0,0 +1,109 @@
+import unittest
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from sgl_jax.srt.layers.sampler import multinomial_with_seed
+
+
+class TestMultinomialWithSeed(unittest.TestCase):
+
+    def test_deterministic_sampling_with_same_seed(self):
+        """Test that same (inputs, seed) pair always yields the same sample."""
+        # Setup test data
+        batch_size = 4
+        vocab_size = 10
+
+        # Create logits that simulate different temperature scenarios
+        flatter_distribution = jnp.array(
+            [
+                [1.0, 1.1, 0.9, 1.2, 0.8, 1.3, 0.7, 1.4, 0.6, 1.5],
+                [2.0, 2.1, 1.9, 2.2, 1.8, 2.3, 1.7, 2.4, 1.6, 2.5],
+                [0.5, 0.6, 0.4, 0.7, 0.3, 0.8, 0.2, 0.9, 0.1, 1.0],
+                [3.0, 3.1, 2.9, 3.2, 2.8, 3.3, 2.7, 3.4, 2.6, 3.5],
+            ],
+            dtype=jnp.bfloat16,
+        )
+
+        flatter_distribution_processed = jax.nn.softmax(flatter_distribution, axis=-1)
+
+        sharper_distribution = jnp.array(
+            [
+                [1.0, 5.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+                [2.0, 2.0, 8.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
+                [0.5, 0.5, 0.5, 7.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
+                [3.0, 3.0, 3.0, 3.0, 9.0, 3.0, 3.0, 3.0, 3.0, 3.0],
+            ],
+            dtype=jnp.bfloat16,
+        )
+
+        sharper_distribution_processed = jax.nn.softmax(sharper_distribution, axis=-1)
+
+        seeds = jnp.array([12345, 67890, 54321, 98765])
+        positions = jnp.array([0, 1, 2, 3])
+
+        test_cases = [
+            ("flatter_distribution", flatter_distribution_processed),
+            ("sharper_distribution", sharper_distribution_processed),
+        ]
+
+        for test_name, inputs in test_cases:
+            with self.subTest(test_name=test_name):
+                # Sample multiple times with the same inputs and seeds
+                samples = []
+                for _ in range(10):  # Run 10 times
+                    sample = multinomial_with_seed((inputs, seeds, positions, None))
+                    samples.append(sample)
+
+                # All samples should be identical
+                first_sample = samples[0]
+                for i, sample in enumerate(samples[1:], 1):
+                    np.testing.assert_array_equal(
+                        first_sample,
+                        sample,
+                        f"Sample {i} differs from first sample for {test_name}",
+                    )
+
+    def test_different_seeds_produce_different_samples(self):
+        """Test that different seeds produce different samples (with high probability)."""
+        batch_size = 1
+        vocab_size = 10
+
+        inputs = jnp.ones((batch_size, vocab_size), dtype=jnp.bfloat16) * 0.1
+        inputs = jax.nn.softmax(inputs, axis=-1)
+        positions = jnp.array([0])
+
+        seeds = [jnp.array([1]), jnp.array([2]), jnp.array([12345]), jnp.array([98765])]
+
+        samples = []
+        for seed in seeds:
+            sample = multinomial_with_seed((inputs, seed, positions, None))
+            samples.append(sample)
+
+        original_len = len(samples)
+        unique_samples = set(tuple(sample.flatten().tolist()) for sample in samples)
+        self.assertEqual(original_len, len(unique_samples))
+
+    def test_output_shape_and_range(self):
+        """Test that output has correct shape and values are in valid range."""
+        batch_size = 3
+        vocab_size = 7
+
+        inputs = jnp.ones((batch_size, vocab_size), dtype=jnp.bfloat16)
+        inputs = jax.nn.softmax(inputs, axis=-1)
+        seeds = jnp.array([1, 2, 3])
+        positions = jnp.array([0, 1, 2])
+
+        sample = multinomial_with_seed((inputs, seeds, positions, None))
+
+        expected_shape = (batch_size, 1)  # Function returns keepdims=True
+        self.assertEqual(sample.shape, expected_shape)
+
+        self.assertTrue(jnp.all(sample >= 0))
+        self.assertTrue(jnp.all(sample < vocab_size))
+        self.assertTrue(sample.dtype in [jnp.int32, jnp.int64])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/test_features.py b/test/srt/test_features.py
index 5e3ea588..0f0e250e 100644
--- a/test/srt/test_features.py
+++ b/test/srt/test_features.py
@@ -124,18 +124,32 @@ def test_abort_all(self):
                 future.result()["meta_info"]["finish_reason"]["type"], "abort"
             )

-    def test_cache_miss(self):
+    def test_cache_miss_prefill(self):
         args = SimpleNamespace(
             base_url=self.base_url,
             text="the capital of France is",
             temperature=0,
-            max_new_tokens=6,
+            max_new_tokens=1,
+        )
+
+        resp = run_curl(args)
+
+        if "cache_miss_count" not in resp["meta_info"]:
+            raise AssertionError("[prefill] cache_miss_count is missing in response")
+        self.assertEqual(resp["meta_info"]["cache_miss_count"], 0)
+
+    def test_cache_miss_decode(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            text="the capital of France is",
+            temperature=0,
+            max_new_tokens=2,
         )

         resp = run_curl(args)

         if "cache_miss_count" not in resp["meta_info"]:
-            raise "cache_miss_count is missed in response"
+            raise AssertionError("[decode] cache_miss_count is missing in response")
         self.assertEqual(resp["meta_info"]["cache_miss_count"], 0)

     def test_logprobs(self):
diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py
index 126ed49e..411fe00c 100644
--- a/test/srt/test_srt_engine.py
+++ b/test/srt/test_srt_engine.py
@@ -33,6 +33,7 @@ def setUpClass(cls):
             precompile_token_paddings=[1024],
             page_size=64,
             log_requests=False,
+            enable_deterministic_sampling=True,
         )

         cls.tokenizer = get_tokenizer(cls.model_path)
@@ -110,3 +111,134 @@ def test_2_engine_prompt_ids_with_sample_n_output_ids(self):
                 True,
             )
             self.assertEqual(decoded_output, item["text"])
+
+    def test_3_engine_sampling_temperature_top_p_top_k_min_p(self):
+        input_strings = ["the capital of France is"]
+
+        def get_sampling_params(max_new_tokens: int = 1):
+            sampling_params = TestSRTEngine.engine.get_default_sampling_params()
+            sampling_params.max_new_tokens = max_new_tokens
+            sampling_params.n = max_new_tokens
+            sampling_params.temperature = 0
+            sampling_params.stop_token_ids = [TestSRTEngine.tokenizer.eos_token_id]
+            sampling_params.skip_special_tokens = True
+            return sampling_params
+
+        def update_sampling_params(
+            params: SamplingParams,
+            temperature: float = None,
+            top_p: float = None,
+            top_k: int = None,
+            min_p: float = None,
+            sampling_seed: int = None,
+        ):
+            if temperature is not None:
+                params.temperature = temperature
+            if top_p is not None:
+                params.top_p = top_p
+            if top_k is not None:
+                params.top_k = top_k
+            if min_p is not None:
+                params.min_p = min_p
+            if sampling_seed is not None:
+                params.sampling_seed = sampling_seed
+
+        cases = {
+            "[greedy] temperature[0.0]_top_p[1.0]_top_k[-1]_min_p[0.0]": (
+                0.0,
+                1.0,
+                -1,
+                0.0,
+            ),
+            "[greedy] temperature[0.5]_top_p[1.0]_top_k[1]_min_p[0.0]": (
+                0.5,
+                1.0,
+                1,
+                0.0,
+            ),
+            "[not_greedy_top_p_0.9] temperature[0.6]_top_p[0.9]_top_k[-1]_min_p[0.0]": (
+                0.6,
+                0.9,
+                -1,
+                0.0,
+            ),
+            "[not_greedy_top_k_10] temperature[0.6]_top_p[1.0]_top_k[10]_min_p[0.0]": (
+                0.6,
+                1.0,
+                10,
+                0.0,
+            ),
+            "[not_greedy_min_p_0.5] temperature[0.6]_top_p[1.0]_top_k[-1]_min_p[0.5]": (
+                0.6,
+                1.0,
+                -1,
+                0.5,
+            ),  # need_min_p_sampling
+            "[not_greedy_sampling_seed_36] temperature[0.5]_top_p[1.0]_top_k[-1]_min_p[0.0]_sampling_seed[36]": (
+                0.5,
+                1.0,
+                -1,
+                0.0,
+                36,
+            ),
+            "[not_greedy_temperature_top_p_top_k_min_p_sampling_seed] temperature[0.5]_top_p[0.9]_top_k[10]_min_p[0.5]_sampling_seed[40]": (
+                0.5,
+                0.9,
+                10,
+                0.5,
+                40,
+            ),
+        }
+
+        prompt_ids_list = [self.tokenize(x) for x in input_strings]
+
+        # prefill
+        sampling_params_prefill = get_sampling_params(1)
+        for case_name, args in cases.items():
+            print(f"[prefill, {case_name}] begins to run")
+            update_sampling_params(sampling_params_prefill, *args)
+            sampling_params_dict = sampling_params_prefill.convert_to_dict()
+            outputs = TestSRTEngine.engine.generate(
+                input_ids=prompt_ids_list,
+                sampling_params=sampling_params_dict,
+            )
+            self.assertEqual(int(outputs[0]["meta_info"]["cache_miss_count"]), 0)
+
+        # decode
+        sampling_params_decode = get_sampling_params(2)
+        for case_name, args in cases.items():
+            print(f"[decode, {case_name}] begins to run")
+            update_sampling_params(sampling_params_decode, *args)
+            sampling_params_dict = sampling_params_decode.convert_to_dict()
+            outputs = TestSRTEngine.engine.generate(
+                input_ids=prompt_ids_list,
+                sampling_params=sampling_params_dict,
+            )
+            self.assertEqual(int(outputs[0]["meta_info"]["cache_miss_count"]), 0)
+
+    def test_4_engine_prompt_ids_with_sample_n_output_ids(self):
+        input_strings = ["the capital of China is"]
+
+        sampling_params = TestSRTEngine.engine.get_default_sampling_params()
+        sampling_params.max_new_tokens = 10
+        sampling_params.n = 20
+        sampling_params.temperature = 0
+        sampling_params.stop_token_ids = [TestSRTEngine.tokenizer.eos_token_id]
+        sampling_params.skip_special_tokens = True
+        # sampling_params.sampling_seed = 30
+
+        sampling_params_dict = sampling_params.convert_to_dict()
+
+        prompt_ids_list = [self.tokenize(x) for x in input_strings]
+        outputs = TestSRTEngine.engine.generate(
+            input_ids=prompt_ids_list,
+            sampling_params=[sampling_params_dict] * 2,
+        )
+
+        # self.assertEqual(len(outputs), 4)
+        # for item in outputs:
+        #     decoded_output = TestSRTEngine.tokenizer.decode(
+        #         item["output_ids"],
+        #         True,
+        #     )
+        #     self.assertEqual(decoded_output, item["text"])
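Finally, a hedged sketch of what the feature buys at the engine level, mirroring the tests above (the `Engine` import path and constructor arguments are assumptions, not verified project API):

```python
from sgl_jax.srt.entrypoints.engine import Engine  # assumed import path

engine = Engine(
    model_path="Qwen/Qwen-7B",           # illustrative model
    enable_deterministic_sampling=True,  # the new server arg from this diff
)
params = {"temperature": 0.7, "max_new_tokens": 8, "sampling_seed": 36}

out1 = engine.generate(prompt="the capital of France is", sampling_params=params)
out2 = engine.generate(prompt="the capital of France is", sampling_params=params)
assert out1["text"] == out2["text"]  # same sampling_seed -> reproducible output
```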