python/sgl_jax/srt/layers/attention/flashattention_backend.py (8 changes: 4 additions & 4 deletions)
@@ -101,14 +101,14 @@ def get_forward_metadata(self, batch: ModelWorkerBatch):
             cu_q_lens = np.concatenate(
                 [
                     np.array([0], dtype=np.int32),
-                    np.cumsum(batch.extend_seq_lens),
+                    np.cumsum(batch.extend_seq_lens, dtype=np.int32),
                 ]
             )
         elif batch.forward_mode == ForwardMode.DECODE:
-            cu_q_lens = jnp.concatenate(
+            cu_q_lens = np.concatenate(
                 [
-                    np.array([0], dtype=jnp.int32),
-                    np.cumsum(jnp.ones(len(batch.seq_lens), dtype=np.int32)),
+                    np.array([0], dtype=np.int32),
+                    np.cumsum(np.ones(len(batch.seq_lens), dtype=np.int32)),
                 ]
             )
         else:
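Why the dtype pin in this hunk matters: np.cumsum promotes integer inputs narrower than the platform integer to the platform default, so on most 64-bit hosts the cumulative sum of an int32 array silently comes back as int64 and upcasts the concatenated cu_q_lens. A minimal standalone sketch (illustrative values, not from this PR):

import numpy as np

extend_seq_lens = np.array([3, 5, 2], dtype=np.int32)

# Without the dtype argument, NumPy widens to the platform integer:
print(np.cumsum(extend_seq_lens).dtype)                  # int64 on most 64-bit hosts
print(np.cumsum(extend_seq_lens, dtype=np.int32).dtype)  # int32

# Pinning the dtype keeps the concatenated prefix-sum array int32 throughout:
cu_q_lens = np.concatenate(
    [np.array([0], dtype=np.int32), np.cumsum(extend_seq_lens, dtype=np.int32)]
)
print(cu_q_lens, cu_q_lens.dtype)  # [ 0  3  8 10] int32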
python/sgl_jax/srt/layers/sampler.py (80 changes: 67 additions & 13 deletions)
@@ -18,14 +18,14 @@ def __init__(self, rngs: nnx.Rngs = None):

    def _greedy_sampling(self, operands):
        """Greedy sampling branch"""
-        logits, _, _ = operands
+        logits, _, _, _ = operands
        batch_next_token_ids = jnp.argmax(logits, -1).flatten()
        logprobs = jax.nn.log_softmax(logits, axis=-1)
        return batch_next_token_ids, logprobs

    def _regular_sampling(self, operands):
        """Regular sampling branch"""
-        logits, sampling_metadata, rng = operands
+        logits, sampling_metadata, positions, rng = operands

        # Post process logits
        processed_logits = jnp.divide(logits, sampling_metadata.temperatures).astype(
@@ -39,6 +39,8 @@ def _regular_sampling(self, operands):
            sampling_metadata.top_ks,
            sampling_metadata.top_ps,
            sampling_metadata.min_ps,
+            positions,
+            sampling_metadata.sampling_seeds,
            sampling_metadata.need_min_p_sampling,
            rng,
        )
@@ -80,18 +82,14 @@ def __call__(
        self,
        logits_output: LogitsProcessorOutput,
        sampling_metadata: SamplingMetadata,
+        positions: jax.Array,
    ):
        """Run a sampler & compute logprobs and update logits_output accordingly.

        Args:
            logits_output: The logits from the model forward
-            sampling_info: Metadata for sampling
-            return_logprob: If set, store the output logprob information to
-                logits_output
-            top_logprobs_nums: Number of top lobprobs per sequence in a batch
-            batch_next_token_ids: next token IDs. If set, skip sampling and only
-                compute output logprobs It is used for speculative decoding which
-                performs sampling in draft workers.
+            sampling_metadata: Metadata for sampling
+            positions: The positions of the tokens in the sequence.
        """

        logits = jnp.reshape(
@@ -101,7 +99,7 @@ def __call__(

        _, rng = jax.random.split(self.rngs.params())

-        operands = (logits, sampling_metadata, rng)
+        operands = (logits, sampling_metadata, positions, rng)
        batch_next_token_ids, logprobs = lax.cond(
            sampling_metadata.is_all_greedy,
            self._greedy_sampling,
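For context on the widened operand tuple above: lax.cond traces both branches with the same operand pytree, so once positions is threaded through for _regular_sampling, _greedy_sampling has to unpack (and ignore) it as well. A minimal sketch of the pattern with simplified names and toy shapes, not the PR's actual code:

import jax
import jax.numpy as jnp
from jax import lax

def greedy(operands):
    logits, _, _ = operands  # positions and rng are unused in this branch
    return jnp.argmax(logits, -1)

def sampled(operands):
    logits, _, rng = operands
    return jax.random.categorical(rng, logits, axis=-1)

logits = jnp.array([[0.1, 2.0, 0.3]])
positions = jnp.array([5])
rng = jax.random.PRNGKey(0)

# Both branches receive the identical (logits, positions, rng) tuple:
token = lax.cond(True, greedy, sampled, (logits, positions, rng))
print(token)  # [1], from the greedy branch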
@@ -158,19 +156,75 @@ def top_k_top_p_min_p_sampling_from_probs_jax(
    top_ks: jax.Array,
    top_ps: jax.Array,
    min_ps: jax.Array,
-    need_min_p_sampling: bool,
-    rng: nnx.Rngs,
+    positions: jax.Array,
+    sampling_seeds: jax.Array = None,
+    need_min_p_sampling: bool = False,
+    rng: nnx.Rngs = None,
):
    """A top-k, top-p and min-p sampling implementation with native jax operations."""
    probs_sort, probs_idx = _sample_part_a(
        probs, top_ks, top_ps, min_ps, need_min_p_sampling
    )

-    sampled_index = random.categorical(rng, jnp.log(probs_sort)).reshape(-1, 1)
+    multinomial_operands = (probs_sort, sampling_seeds, positions, rng)
+    sampled_index = lax.cond(
+        sampling_seeds is not None,
+        multinomial_with_seed,
+        multinomial,
+        multinomial_operands,
+    )

    return _sample_part_b(probs_idx, sampled_index)


+def multinomial(
+    operands,
+) -> jax.Array:
+    inputs, _, _, rng = operands
+    return random.categorical(rng, jnp.log(inputs)).reshape(-1, 1)
+
+
+def multinomial_with_seed(
+    operands,
+) -> jax.Array:
+    """
+    Note:
+        1. This implementation is copied from https://github.com/sgl-project/sglang/blob/e2ac7888b8cb1fd6c33a7ec58d27a5f5b5b24e0c/python/sglang/srt/layers/sampler.py#L268.
+        2. Per the last response in the issue below, the four fixed large primes can be chosen freely; 8589934591 is outside the uint32 range, so it is replaced with 805306457.
+           - issue: https://github.com/sgl-project/sglang/issues/10938
+
+    Samples n elements from an input array `inputs` of shape (n, m) using
+    a unique random seed for each row.
+
+    Args:
+        inputs: A float array of shape (n, m) representing n categorical
+            distributions with m categories each. The values are treated
+            as weights and do not need to sum to 1.
+        seed: An integer array of shape (n,) containing the random seed
+            for each corresponding row in `inputs`.
+        positions: The positions of the tokens in the sequence.
+
+    Returns:
+        An array of shape (n,) where the i-th element is an index sampled
+        from the distribution in `inputs[i]` using `seed[i]`.
+    """
+    inputs, seed, positions, _ = operands
+    if seed is None:
+        # note: this branch keeps the signature compatible with lax.cond
+        return multinomial(operands)
+    n, m = inputs.shape
+    step_seed = seed * 19349663 ^ positions * 73856093
+    seed_expanded = step_seed[:, None]
+    col_indices = jnp.arange(m)[None, :]
+    hashed = seed_expanded * 805306457 ^ col_indices * 479001599
+    uniform_samples = (hashed % (2**24)).astype(jnp.float32) / (2**24)
+    epsilon = 1e-9
+    gumbel_noise = -jnp.log(-jnp.log(uniform_samples + epsilon) + epsilon)
+    log_probs = jnp.log(inputs + epsilon)
+    perturbed_log_probs = log_probs + gumbel_noise
+    return jnp.argmax(perturbed_log_probs, axis=1, keepdims=True)


def _apply_min_p_filter(operands):
    """Apply min_p filtering when need_min_p_sampling=True"""
    probs_sort, min_ps = operands
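As a sanity check on the determinism claim, here is a small self-contained sketch using the same hash constants as multinomial_with_seed: a fixed (seed, position) pair reproduces the same sample with no PRNG key involved (simplified shapes, illustrative values):

import jax.numpy as jnp

def seeded_pick(probs, seed, positions):
    # Hash (seed, position) into per-element uniforms, then apply the Gumbel trick.
    n, m = probs.shape
    step_seed = seed * 19349663 ^ positions * 73856093
    hashed = step_seed[:, None] * 805306457 ^ jnp.arange(m)[None, :] * 479001599
    uniform = (hashed % (2**24)).astype(jnp.float32) / (2**24)
    gumbel = -jnp.log(-jnp.log(uniform + 1e-9) + 1e-9)
    return jnp.argmax(jnp.log(probs + 1e-9) + gumbel, axis=1)

probs = jnp.array([[0.1, 0.7, 0.2]])
seed = jnp.array([42])
positions = jnp.array([7])

# Same inputs, same output, every time:
print(seeded_pick(probs, seed, positions) == seeded_pick(probs, seed, positions))  # [ True]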
python/sgl_jax/srt/managers/schedule_batch.py (1 change: 1 addition & 0 deletions)
@@ -47,6 +47,7 @@
 GLOBAL_SERVER_ARGS_KEYS = [
     "device",
     "disable_radix_cache",
+    "enable_deterministic_sampling",
 ]

 PADDING_BUCKETS = [1 << i for i in range(6, 21)]
@@ -142,6 +142,12 @@ def process_batch_result_prefill(
             logprob_pt += num_input_logprobs

         batch.cache_miss_count = cache_miss_count
+
+        if batch.cache_miss_count > 0:
+            logger.info(
+                f"Prefill batch. #bid: {result.bid}, #cache_miss: {cache_miss_count}"
+            )
+
         self.stream_output(
             batch.reqs, batch.return_logprob, skip_stream_req, cache_miss_count
         )
python/sgl_jax/srt/managers/tp_worker.py (35 changes: 33 additions & 2 deletions)
@@ -11,6 +11,7 @@
 import numpy as np
 from flax import nnx
 from jax.experimental.multihost_utils import broadcast_one_to_all
+from jax.sharding import NamedSharding, PartitionSpec
 from tqdm import tqdm

 from sgl_jax.srt.configs.model_config import ModelConfig
@@ -33,6 +34,7 @@
     PRECOMPILE_DEFAULT_BS_PADDINGS,
     PRECOMPILE_DEFAULT_TOKEN_PADDINGS,
 )
+from sgl_jax.srt.utils.jax_utils import device_array

 logger = logging.getLogger(__name__)

@@ -212,6 +214,9 @@ def precompile_extend(self, future_token_ids_map=None):
                ForwardMode.EXTEND,
                self.precompile_cache_loc_paddings[-1],
            )
+            sampling_metadata = SamplingMetadata.from_model_worker_batch(
+                model_worker_batch, 0, self.mesh
+            )
            model_worker_batch.forward_batch = ForwardBatch.init_new(
                model_worker_batch, self.model_runner
            )
@@ -222,8 +227,10 @@ def precompile_extend(self, future_token_ids_map=None):
                    future_token_ids_map,
                )
            )
-            self.forward_batch_generation(model_worker_batch, None, True)
+            self.forward_batch_generation(
+                model_worker_batch, None, False, sampling_metadata
+            )

        end_time = time.perf_counter()
        logger.info("[EXTEND] Precompile finished in %.0f secs", end_time - start_time)

@@ -266,6 +273,13 @@ def precompile_decode(self, future_token_ids_map=None):
        end_time = time.perf_counter()
        logger.info("[DECODE] Precompile finished in %.0f secs", end_time - start_time)

+    def set_forward_metadata(self, model_worker_batch: ModelWorkerBatch):
+        self.model_runner.attn_backend.forward_metadata = (
+            self.worker.model_runner.attn_backend.get_forward_metadata(
+                model_worker_batch
+            )
+        )
+
    def get_max_padded_size(self):
        """Calculate the max padded batch size and token nums.

@@ -393,6 +407,21 @@ def forward_batch_generation(
        )

        self.model_runner.attn_backend.forward_metadata = forward_metadata
+        # note: put positions on devices again because the forward_batch has been donated
+        if not skip_sample:
+            positions = (
+                model_worker_batch.positions
+                if model_worker_batch.forward_mode.is_decode()
+                else model_worker_batch.seq_lens - 1
+            )
+            positions_device = device_array(
+                positions,
+                sharding=(
+                    NamedSharding(self.model_runner.mesh, PartitionSpec())
+                    if jax.process_count() == 1
+                    else None
+                ),
+            )
        logits_output, cache_miss_count = self.model_runner.forward(
            forward_batch,
            logits_metadata=LogitsMetadata.from_model_worker_batch(
@@ -411,7 +440,9 @@ def forward_batch_generation(

        with jtu.count_pjit_cpp_cache_miss() as count:
            next_token_ids_device = self.model_runner.sample(
-                logits_output, sampling_metadata
+                logits_output,
+                sampling_metadata,
+                positions_device,
            )
        sample_cache_miss_count = count()

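The "# note:" comment in the hunk above is the key detail: forward_batch is donated to the jitted forward pass, which invalidates buffers reachable from it, so positions must be placed on device again before sampling. A minimal sketch of JAX buffer donation (toy function, not part of this PR):

import jax
import jax.numpy as jnp

# donate_argnums lets XLA reuse the input buffer for the output.
donating_inc = jax.jit(lambda x: x + 1, donate_argnums=0)

x = jnp.arange(4.0)
y = donating_inc(x)
print(y)  # [1. 2. 3. 4.]

try:
    print(x)  # on backends that honor donation (TPU/GPU) the buffer is deleted
except RuntimeError:
    print("x was donated; its buffer is no longer readable")
# On CPU, JAX may skip donation and only emit a warning instead.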
python/sgl_jax/srt/model_executor/model_runner.py (16 changes: 14 additions & 2 deletions)
@@ -17,7 +17,11 @@
 from sgl_jax.srt.configs.model_config import AttentionArch, MockModelConfig, ModelConfig
 from sgl_jax.srt.layers.logits_processor import LogitsMetadata, LogitsProcessorOutput
 from sgl_jax.srt.layers.sampler import Sampler
-from sgl_jax.srt.managers.schedule_batch import ModelWorkerBatch
+from sgl_jax.srt.managers.schedule_batch import (
+    GLOBAL_SERVER_ARGS_KEYS,
+    ModelWorkerBatch,
+    global_server_args_dict,
+)
 from sgl_jax.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     PagedTokenToKVPoolAllocator,
@@ -83,6 +87,12 @@ def __init__(
        self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA

        self.forward_pass_id = 0
+
+        # Global vars
+        global_server_args_dict.update(
+            {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
+        )
+
        self.model_loader = JAXModelLoader(
            load_config=LoadConfig(
                load_format=LoadFormat.JAX, download_dir=server_args.download_dir
@@ -431,19 +441,21 @@ def sample(
        self,
        logits_output: LogitsProcessorOutput,
        sampling_metadata: SamplingMetadata,
+        positions: jax.Array,
    ) -> jax.Array:
        """Sample and compute logprobs and update logits_output.

        Args:
            logits_output: The logits output from the model forward
-            forward_batch: The forward batch that generates logits_output
+            positions: The positions of the tokens in the sequence.

        Returns:
            A list of next_token_ids
        """
        return self.jitted_sampler(
            logits_output,
            sampling_metadata,
+            positions,
        )


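For the plumbing added in this file and in schedule_batch.py: selected server arguments are mirrored into a process-global dict so distant modules (such as the sampler) can read flags like enable_deterministic_sampling without threading ServerArgs everywhere. A reduced sketch with stand-in types and values, not the repo's real ServerArgs:

from dataclasses import dataclass

GLOBAL_SERVER_ARGS_KEYS = ["device", "disable_radix_cache", "enable_deterministic_sampling"]
global_server_args_dict = {}

@dataclass
class ServerArgs:  # stand-in for sgl_jax's real ServerArgs
    device: str = "tpu"
    disable_radix_cache: bool = False
    enable_deterministic_sampling: bool = True

# Mirrors ModelRunner.__init__: copy only the whitelisted keys.
global_server_args_dict.update(
    {k: getattr(ServerArgs(), k) for k in GLOBAL_SERVER_ARGS_KEYS}
)
print(global_server_args_dict["enable_deterministic_sampling"])  # True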