Commit 41ce0a8

add multinomial_with_seed for sampler and test_sampler.py (#12)
1 parent 7424998 commit 41ce0a8

12 files changed (+482 lines, −43 lines)

python/sgl_jax/srt/layers/attention/flashattention_backend.py

Lines changed: 4 additions & 4 deletions
@@ -101,14 +101,14 @@ def get_forward_metadata(self, batch: ModelWorkerBatch):
             cu_q_lens = np.concatenate(
                 [
                     np.array([0], dtype=np.int32),
-                    np.cumsum(batch.extend_seq_lens),
+                    np.cumsum(batch.extend_seq_lens, dtype=np.int32),
                 ]
             )
         elif batch.forward_mode == ForwardMode.DECODE:
-            cu_q_lens = jnp.concatenate(
+            cu_q_lens = np.concatenate(
                 [
-                    np.array([0], dtype=jnp.int32),
-                    np.cumsum(jnp.ones(len(batch.seq_lens), dtype=np.int32)),
+                    np.array([0], dtype=np.int32),
+                    np.cumsum(np.ones(len(batch.seq_lens), dtype=np.int32)),
                 ]
             )
         else:
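
Why the dtype is pinned here: NumPy promotes integer accumulations narrower than the platform integer, so `np.cumsum` over an int32 array returns int64 on 64-bit hosts, silently changing the dtype of the concatenated `cu_q_lens` (the decode branch also switches from `jnp` to `np`, presumably so the metadata stays a host-side NumPy array throughout). A minimal sketch of the promotion, assuming `batch.extend_seq_lens` is int32 as the surrounding code suggests (the sample lengths are invented):

    import numpy as np

    lens = np.array([3, 5, 2], dtype=np.int32)  # invented sample lengths
    # NumPy promotes int32 accumulations to the platform integer:
    print(np.cumsum(lens).dtype)                  # int64 on 64-bit hosts
    # Pinning the dtype keeps cu_q_lens uniformly int32, as in the diff:
    print(np.cumsum(lens, dtype=np.int32).dtype)  # int32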

python/sgl_jax/srt/layers/sampler.py

Lines changed: 67 additions & 13 deletions
@@ -18,14 +18,14 @@ def __init__(self, rngs: nnx.Rngs = None):
 
     def _greedy_sampling(self, operands):
         """Greedy sampling branch"""
-        logits, _, _ = operands
+        logits, _, _, _ = operands
         batch_next_token_ids = jnp.argmax(logits, -1).flatten()
         logprobs = jax.nn.log_softmax(logits, axis=-1)
         return batch_next_token_ids, logprobs
 
     def _regular_sampling(self, operands):
         """Regular sampling branch"""
-        logits, sampling_metadata, rng = operands
+        logits, sampling_metadata, positions, rng = operands
 
         # Post process logits
         processed_logits = jnp.divide(logits, sampling_metadata.temperatures).astype(
@@ -39,6 +39,8 @@ def _regular_sampling(self, operands):
             sampling_metadata.top_ks,
             sampling_metadata.top_ps,
             sampling_metadata.min_ps,
+            positions,
+            sampling_metadata.sampling_seeds,
             sampling_metadata.need_min_p_sampling,
             rng,
         )
@@ -80,18 +82,14 @@ def __call__(
         self,
         logits_output: LogitsProcessorOutput,
         sampling_metadata: SamplingMetadata,
+        positions: jax.Array,
     ):
         """Run a sampler & compute logprobs and update logits_output accordingly.
 
         Args:
             logits_output: The logits from the model forward
-            sampling_info: Metadata for sampling
-            return_logprob: If set, store the output logprob information to
-                logits_output
-            top_logprobs_nums: Number of top lobprobs per sequence in a batch
-            batch_next_token_ids: next token IDs. If set, skip sampling and only
-                compute output logprobs It is used for speculative decoding which
-                performs sampling in draft workers.
+            sampling_metadata: Metadata for sampling
+            positions: The positions of the tokens in the sequence.
         """
 
         logits = jnp.reshape(
@@ -101,7 +99,7 @@ def __call__(
 
         _, rng = jax.random.split(self.rngs.params())
 
-        operands = (logits, sampling_metadata, rng)
+        operands = (logits, sampling_metadata, positions, rng)
         batch_next_token_ids, logprobs = lax.cond(
             sampling_metadata.is_all_greedy,
             self._greedy_sampling,
@@ -158,19 +156,75 @@ def top_k_top_p_min_p_sampling_from_probs_jax(
     top_ks: jax.Array,
     top_ps: jax.Array,
     min_ps: jax.Array,
-    need_min_p_sampling: bool,
-    rng: nnx.Rngs,
+    positions: jax.Array,
+    sampling_seeds: jax.Array = None,
+    need_min_p_sampling: bool = False,
+    rng: nnx.Rngs = None,
 ):
     """A top-k, top-p and min-p sampling implementation with native jax operations."""
     probs_sort, probs_idx = _sample_part_a(
         probs, top_ks, top_ps, min_ps, need_min_p_sampling
     )
 
-    sampled_index = random.categorical(rng, jnp.log(probs_sort)).reshape(-1, 1)
+    multinomial_operands = (probs_sort, sampling_seeds, positions, rng)
+    sampled_index = lax.cond(
+        sampling_seeds is not None,
+        multinomial_with_seed,
+        multinomial,
+        multinomial_operands,
+    )
 
     return _sample_part_b(probs_idx, sampled_index)
 
 
+def multinomial(
+    operands,
+) -> jax.Array:
+    inputs, _, _, rng = operands
+    return random.categorical(rng, jnp.log(inputs)).reshape(-1, 1)
+
+
+def multinomial_with_seed(
+    operands,
+) -> jax.Array:
+    """
+    Note:
+        1. This implementation is copied from https://github.com/sgl-project/sglang/blob/e2ac7888b8cb1fd6c33a7ec58d27a5f5b5b24e0c/python/sglang/srt/layers/sampler.py#L268.
+        2. Per the last response in the issue below, the four large prime constants can be chosen freely; 8589934591 is out of uint32 range, so it is replaced with 805306457.
+        - issue: https://github.com/sgl-project/sglang/issues/10938
+
+    Samples n elements from an input array `inputs` of shape (n, m) using
+    a unique random seed for each row.
+
+    Args:
+        inputs: A float array of shape (n, m) representing n categorical
+            distributions with m categories each. The values are treated
+            as weights and do not need to sum to 1.
+        seed: An integer array of shape (n,) containing the random seed
+            for each corresponding row in `inputs`.
+        positions: The positions of the tokens in the sequence.
+
+    Returns:
+        An array of shape (n, 1) where the i-th element is an index sampled
+        from the distribution in `inputs[i]` using `seed[i]`.
+    """
+    inputs, seed, positions, _ = operands
+    if seed is None:
+        # note: this branch keeps the operand signature compatible with lax.cond
+        return multinomial(operands)
+    n, m = inputs.shape
+    step_seed = seed * 19349663 ^ positions * 73856093
+    seed_expanded = step_seed[:, None]
+    col_indices = jnp.arange(m)[None, :]
+    hashed = seed_expanded * 805306457 ^ col_indices * 479001599
+    uniform_samples = (hashed % (2**24)).astype(jnp.float32) / (2**24)
+    epsilon = 1e-9
+    gumbel_noise = -jnp.log(-jnp.log(uniform_samples + epsilon) + epsilon)
+    log_probs = jnp.log(inputs + epsilon)
+    perturbed_log_probs = log_probs + gumbel_noise
+    return jnp.argmax(perturbed_log_probs, axis=1, keepdims=True)
+
+
 def _apply_min_p_filter(operands):
     """Apply min_p filtering when need_min_p_sampling=True"""
     probs_sort, min_ps = operands
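
To see what the seeded path above buys, here is a self-contained sketch of the same hash-based Gumbel-max construction (re-implemented standalone rather than imported from sgl_jax; the probabilities, seeds, and positions are invented). Because each row's uniforms are a pure function of (seed, position, column), the draw is reproducible and independent of how requests are batched together:

    import jax.numpy as jnp

    def seeded_multinomial(inputs, seed, positions):
        # Hash (seed, position) into a per-row stream, then mix in the
        # column index; int32 overflow wraps, which is fine for hash mixing.
        m = inputs.shape[1]
        step_seed = seed * 19349663 ^ positions * 73856093
        hashed = step_seed[:, None] * 805306457 ^ jnp.arange(m)[None, :] * 479001599
        uniforms = (hashed % (2**24)).astype(jnp.float32) / (2**24)
        eps = 1e-9
        # Gumbel-max trick: argmax of log-weights + Gumbel noise ~ categorical
        gumbel = -jnp.log(-jnp.log(uniforms + eps) + eps)
        return jnp.argmax(jnp.log(inputs + eps) + gumbel, axis=1)

    probs = jnp.array([[0.1, 0.7, 0.2], [0.5, 0.25, 0.25]])  # invented
    seeds = jnp.array([42, 7])
    pos = jnp.array([0, 0])
    print(seeded_multinomial(probs, seeds, pos))
    # Row 0 alone yields the same sample: the draw depends only on
    # (seed, position), not on the rest of the batch.
    print(seeded_multinomial(probs[:1], seeds[:1], pos[:1]))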

python/sgl_jax/srt/managers/schedule_batch.py

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@
 GLOBAL_SERVER_ARGS_KEYS = [
     "device",
     "disable_radix_cache",
+    "enable_deterministic_sampling",
 ]
 
 PADDING_BUCKETS = [1 << i for i in range(6, 21)]

python/sgl_jax/srt/managers/scheduler_output_processor_mixin.py

Lines changed: 6 additions & 0 deletions
@@ -142,6 +142,12 @@ def process_batch_result_prefill(
                 logprob_pt += num_input_logprobs
 
         batch.cache_miss_count = cache_miss_count
+
+        if batch.cache_miss_count > 0:
+            logger.info(
+                f"Prefill batch. #bid: {result.bid}, #cache_miss: {cache_miss_count}"
+            )
+
         self.stream_output(
             batch.reqs, batch.return_logprob, skip_stream_req, cache_miss_count
         )

python/sgl_jax/srt/managers/tp_worker.py

Lines changed: 33 additions & 2 deletions
@@ -11,6 +11,7 @@
 import numpy as np
 from flax import nnx
 from jax.experimental.multihost_utils import broadcast_one_to_all
+from jax.sharding import NamedSharding, PartitionSpec
 from tqdm import tqdm
 
 from sgl_jax.srt.configs.model_config import ModelConfig
@@ -33,6 +34,7 @@
     PRECOMPILE_DEFAULT_BS_PADDINGS,
     PRECOMPILE_DEFAULT_TOKEN_PADDINGS,
 )
+from sgl_jax.srt.utils.jax_utils import device_array
 
 logger = logging.getLogger(__name__)
 
@@ -212,6 +214,9 @@ def precompile_extend(self, future_token_ids_map=None):
                 ForwardMode.EXTEND,
                 self.precompile_cache_loc_paddings[-1],
             )
+            sampling_metadata = SamplingMetadata.from_model_worker_batch(
+                model_worker_batch, 0, self.mesh
+            )
             model_worker_batch.forward_batch = ForwardBatch.init_new(
                 model_worker_batch, self.model_runner
             )
@@ -222,8 +227,10 @@
                     future_token_ids_map,
                 )
             )
-            self.forward_batch_generation(model_worker_batch, None, True)
 
+            self.forward_batch_generation(
+                model_worker_batch, None, False, sampling_metadata
+            )
         end_time = time.perf_counter()
         logger.info("[EXTEND] Precompile finished in %.0f secs", end_time - start_time)
 
@@ -266,6 +273,13 @@ def precompile_decode(self, future_token_ids_map=None):
         end_time = time.perf_counter()
         logger.info("[DECODE] Precompile finished in %.0f secs", end_time - start_time)
 
+    def set_forward_metadata(self, model_worker_batch: ModelWorkerBatch):
+        self.model_runner.attn_backend.forward_metadata = (
+            self.worker.model_runner.attn_backend.get_forward_metadata(
+                model_worker_batch
+            )
+        )
+
     def get_max_padded_size(self):
         """Calculate the max padded batch size and token nums.
 
@@ -393,6 +407,21 @@ def forward_batch_generation(
         )
 
         self.model_runner.attn_backend.forward_metadata = forward_metadata
+        # note: put positions on device again because the forward_batch has been donated
+        if not skip_sample:
+            positions = (
+                model_worker_batch.positions
+                if model_worker_batch.forward_mode.is_decode()
+                else model_worker_batch.seq_lens - 1
+            )
+            positions_device = device_array(
+                positions,
+                sharding=(
+                    NamedSharding(self.model_runner.mesh, PartitionSpec())
+                    if jax.process_count() == 1
+                    else None
+                ),
+            )
         logits_output, cache_miss_count = self.model_runner.forward(
             forward_batch,
             logits_metadata=LogitsMetadata.from_model_worker_batch(
@@ -411,7 +440,9 @@ def forward_batch_generation(
 
         with jtu.count_pjit_cpp_cache_miss() as count:
             next_token_ids_device = self.model_runner.sample(
-                logits_output, sampling_metadata
+                logits_output,
+                sampling_metadata,
+                positions_device,
             )
             sample_cache_miss_count = count()
 
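Background for the "donated" comment above: when a jitted function donates its input buffers, the caller's references to those arrays become invalid, so anything needed after the forward pass (here, `positions`) must be placed on device again. A toy illustration of donation plus the replicated sharding used above (`device_array` is the repo's helper, approximated here with `jax.device_put`; the one-axis mesh and array shapes are invented):

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Donation invalidates the caller's buffer: after this call, `x` may no
    # longer be readable (on TPU/GPU; CPU ignores donation with a warning).
    double = jax.jit(lambda x: x * 2, donate_argnums=0)
    x = jnp.arange(4)
    y = double(x)

    # Hence positions are re-placed before sampling. An empty PartitionSpec()
    # replicates the array across every device of the mesh:
    mesh = Mesh(jax.devices(), axis_names=("tensor",))  # invented mesh
    positions = jnp.arange(8, dtype=jnp.int32)
    positions_device = jax.device_put(positions, NamedSharding(mesh, PartitionSpec()))
    print(positions_device.sharding.is_fully_replicated)  # True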
python/sgl_jax/srt/model_executor/model_runner.py

Lines changed: 14 additions & 2 deletions
@@ -17,7 +17,11 @@
 from sgl_jax.srt.configs.model_config import AttentionArch, MockModelConfig, ModelConfig
 from sgl_jax.srt.layers.logits_processor import LogitsMetadata, LogitsProcessorOutput
 from sgl_jax.srt.layers.sampler import Sampler
-from sgl_jax.srt.managers.schedule_batch import ModelWorkerBatch
+from sgl_jax.srt.managers.schedule_batch import (
+    GLOBAL_SERVER_ARGS_KEYS,
+    ModelWorkerBatch,
+    global_server_args_dict,
+)
 from sgl_jax.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     PagedTokenToKVPoolAllocator,
@@ -83,6 +87,12 @@ def __init__(
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
 
         self.forward_pass_id = 0
+
+        # Global vars
+        global_server_args_dict.update(
+            {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
+        )
+
         self.model_loader = JAXModelLoader(
             load_config=LoadConfig(
                 load_format=LoadFormat.JAX, download_dir=server_args.download_dir
@@ -431,19 +441,21 @@ def sample(
         self,
         logits_output: LogitsProcessorOutput,
         sampling_metadata: SamplingMetadata,
+        positions: jax.Array,
     ) -> jax.Array:
         """Sample and compute logprobs and update logits_output.
 
         Args:
             logits_output: The logits output from the model forward
             forward_batch: The forward batch that generates logits_output
-
+            positions: The positions of the tokens in the sequence.
         Returns:
             A list of next_token_ids
         """
         return self.jitted_sampler(
             logits_output,
             sampling_metadata,
+            positions,
         )

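The `global_server_args_dict.update` call above is what carries the new `enable_deterministic_sampling` flag from the server arguments into modules that read the global dict rather than importing the server config. A minimal model of that pattern (the trimmed-down `ServerArgs` stand-in is hypothetical; the real class has many more fields):

    from dataclasses import dataclass

    GLOBAL_SERVER_ARGS_KEYS = [
        "device",
        "disable_radix_cache",
        "enable_deterministic_sampling",
    ]
    global_server_args_dict = {}

    @dataclass
    class ServerArgs:  # hypothetical stand-in
        device: str = "tpu"
        disable_radix_cache: bool = False
        enable_deterministic_sampling: bool = True

    server_args = ServerArgs()
    # Same update pattern as the diff: copy only the whitelisted keys.
    global_server_args_dict.update(
        {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
    )
    print(global_server_args_dict["enable_deterministic_sampling"])  # True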