@@ -18,14 +18,14 @@ def __init__(self, rngs: nnx.Rngs = None):
18
18
19
19
def _greedy_sampling (self , operands ):
20
20
"""Greedy sampling branch"""
21
- logits , _ , _ = operands
21
+ logits , _ , _ , _ = operands
22
22
batch_next_token_ids = jnp .argmax (logits , - 1 ).flatten ()
23
23
logprobs = jax .nn .log_softmax (logits , axis = - 1 )
24
24
return batch_next_token_ids , logprobs
25
25
26
26
def _regular_sampling (self , operands ):
27
27
"""Regular sampling branch"""
28
- logits , sampling_metadata , rng = operands
28
+ logits , sampling_metadata , positions , rng = operands
29
29
30
30
# Post process logits
31
31
processed_logits = jnp .divide (logits , sampling_metadata .temperatures ).astype (
@@ -39,6 +39,8 @@ def _regular_sampling(self, operands):
39
39
sampling_metadata .top_ks ,
40
40
sampling_metadata .top_ps ,
41
41
sampling_metadata .min_ps ,
42
+ positions ,
43
+ sampling_metadata .sampling_seeds ,
42
44
sampling_metadata .need_min_p_sampling ,
43
45
rng ,
44
46
)
@@ -80,18 +82,14 @@ def __call__(
80
82
self ,
81
83
logits_output : LogitsProcessorOutput ,
82
84
sampling_metadata : SamplingMetadata ,
85
+ positions : jax .Array ,
83
86
):
84
87
"""Run a sampler & compute logprobs and update logits_output accordingly.
85
88
86
89
Args:
87
90
logits_output: The logits from the model forward
88
- sampling_info: Metadata for sampling
89
- return_logprob: If set, store the output logprob information to
90
- logits_output
91
- top_logprobs_nums: Number of top lobprobs per sequence in a batch
92
- batch_next_token_ids: next token IDs. If set, skip sampling and only
93
- compute output logprobs It is used for speculative decoding which
94
- performs sampling in draft workers.
91
+ sampling_metadata: Metadata for sampling
92
+ positions: The positions of the tokens in the sequence.
95
93
"""
96
94
97
95
logits = jnp .reshape (
@@ -101,7 +99,7 @@ def __call__(
101
99
102
100
_ , rng = jax .random .split (self .rngs .params ())
103
101
104
- operands = (logits , sampling_metadata , rng )
102
+ operands = (logits , sampling_metadata , positions , rng )
105
103
batch_next_token_ids , logprobs = lax .cond (
106
104
sampling_metadata .is_all_greedy ,
107
105
self ._greedy_sampling ,
@@ -158,19 +156,75 @@ def top_k_top_p_min_p_sampling_from_probs_jax(
158
156
top_ks : jax .Array ,
159
157
top_ps : jax .Array ,
160
158
min_ps : jax .Array ,
161
- need_min_p_sampling : bool ,
162
- rng : nnx .Rngs ,
159
+ positions : jax .Array ,
160
+ sampling_seeds : jax .Array = None ,
161
+ need_min_p_sampling : bool = False ,
162
+ rng : nnx .Rngs = None ,
163
163
):
164
164
"""A top-k, top-p and min-p sampling implementation with native jax operations."""
165
165
probs_sort , probs_idx = _sample_part_a (
166
166
probs , top_ks , top_ps , min_ps , need_min_p_sampling
167
167
)
168
168
169
- sampled_index = random .categorical (rng , jnp .log (probs_sort )).reshape (- 1 , 1 )
169
+ multinomial_operands = (probs_sort , sampling_seeds , positions , rng )
170
+ sampled_index = lax .cond (
171
+ sampling_seeds is not None ,
172
+ multinomial_with_seed ,
173
+ multinomial ,
174
+ multinomial_operands ,
175
+ )
170
176
171
177
return _sample_part_b (probs_idx , sampled_index )
172
178
173
179
180
+ def multinomial (
181
+ operands ,
182
+ ) -> jax .Array :
183
+ inputs , _ , _ , rng = operands
184
+ return random .categorical (rng , jnp .log (inputs )).reshape (- 1 , 1 )
185
+
186
+
187
+ def multinomial_with_seed (
188
+ operands ,
189
+ ) -> jax .Array :
190
+ """
191
+ Note:
192
+ 1. This implementation is copied from https://github.com/sgl-project/sglang/blob/e2ac7888b8cb1fd6c33a7ec58d27a5f5b5b24e0c/python/sglang/srt/layers/sampler.py#L268.
193
+ 2. Based on the last response in the issue below, the four fixed large prime numbers can be chosen freely. 8589934591 is outside the uint32 range, so it is replaced with 805306457.
194
+ - issue: https://github.com/sgl-project/sglang/issues/10938
195
+
196
+ Samples n elements from an input array `inputs` of shape (n, m) using
197
+ a unique random seed for each row.
198
+
199
+ Args:
200
+ inputs: A float array of shape (n, m) representing n categorical
201
+ distributions with m categories each. The values are treated
202
+ as weights and do not need to sum to 1.
203
+ seed: An integer array of shape (n,) containing the random seed
204
+ for each corresponding row in `inputs`.
205
+ positions: The positions of the tokens in the sequence.
206
+
207
+ Returns:
208
+ A array of shape (n,) where the i-th element is an index sampled
209
+ from the distribution in `inputs[i]` using `seed[i]`.
210
+ """
211
+ inputs , seed , positions , _ = operands
212
+ if seed is None :
213
+ # note: this code is kept to stay compatible with lax.cond
214
+ return multinomial (operands )
215
+ n , m = inputs .shape
216
+ step_seed = seed * 19349663 ^ positions * 73856093
217
+ seed_expanded = step_seed [:, None ]
218
+ col_indices = jnp .arange (m )[None , :]
219
+ hashed = seed_expanded * 805306457 ^ col_indices * 479001599
220
+ uniform_samples = (hashed % (2 ** 24 )).astype (jnp .float32 ) / (2 ** 24 )
221
+ epsilon = 1e-9
222
+ gumbel_noise = - jnp .log (- jnp .log (uniform_samples + epsilon ) + epsilon )
223
+ log_probs = jnp .log (inputs + epsilon )
224
+ perturbed_log_probs = log_probs + gumbel_noise
225
+ return jnp .argmax (perturbed_log_probs , axis = 1 , keepdims = True )
226
+
227
+
174
228
def _apply_min_p_filter (operands ):
175
229
"""Apply min_p filtering when need_min_p_sampling=True"""
176
230
probs_sort , min_ps = operands
0 commit comments