diff --git a/benchmark/profile_pipeline_api.py b/benchmark/profile_pipeline_api.py index 9df88e6907..ebffdd317c 100644 --- a/benchmark/profile_pipeline_api.py +++ b/benchmark/profile_pipeline_api.py @@ -134,7 +134,7 @@ class Engine: def __init__(self, model_path: str, engine_config, csv: str): self.pipe = pipeline(model_path, backend_config=engine_config, log_level='ERROR') self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - + self.return_routed_experts = getattr(self.pipe.backend_config, 'enable_return_routed_experts', False) self.csv = csv def process_request(self, requests, profiler: Profiler, temperature, top_p, top_k, stream_output): @@ -146,6 +146,7 @@ def process_request(self, requests, profiler: Profiler, temperature, top_p, top_ top_k=top_k, ignore_eos=True, do_sample=False, + return_routed_experts=self.return_routed_experts, max_new_tokens=output_len) for _, _, output_len in requests ] @@ -254,6 +255,7 @@ def parse_args(): # pytorch engine args pt_group = parser.add_argument_group('PyTorch engine arguments') ArgumentHelper.eager_mode(pt_group) + ArgumentHelper.enable_return_routed_experts(pt_group) tp_act = ArgumentHelper.tp(pt_group) cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group) @@ -302,6 +304,7 @@ def main(): thread_safe=False, eager_mode=args.eager_mode, enable_prefix_caching=args.enable_prefix_caching, + enable_return_routed_experts=args.enable_return_routed_experts, ) engine = Engine(args.model_path, engine_config, csv=args.csv) diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 8f67743951..ee46389c19 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -97,6 +97,7 @@ def add_parser_api_server(): ArgumentHelper.dllm_unmasking_strategy(pt_group) ArgumentHelper.dllm_denoising_steps(pt_group) ArgumentHelper.dllm_confidence_threshold(pt_group) + ArgumentHelper.enable_return_routed_experts(pt_group) # common engine args dtype_act = ArgumentHelper.dtype(pt_group) @@ -228,6 +229,7 @@ def api_server(args): dllm_unmasking_strategy=args.dllm_unmasking_strategy, dllm_denoising_steps=args.dllm_denoising_steps, dllm_confidence_threshold=args.dllm_confidence_threshold, + enable_return_routed_experts=args.enable_return_routed_experts, ) else: from lmdeploy.messages import TurbomindEngineConfig diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index a53a3cdc86..11df1218b7 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -667,6 +667,15 @@ def dllm_confidence_threshold(parser): default=0.85, help='The confidence threshold for dllm.') + @staticmethod + def enable_return_routed_experts(parser): + """Add argument return routed experts to parser.""" + + return parser.add_argument('--enable-return-routed-experts', + action='store_true', + default=False, + help='Whether to output routed expert ids for replay') + # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py class FlexibleArgumentParser(argparse.ArgumentParser): diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 57725eb23f..6711b630df 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -117,6 +117,9 @@ class GenerationConfig: preserve_cache: bool = False migration_request: Optional[MigrationRequest] = None + # router replay + return_routed_experts: bool = False + def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer): """Convert stop_words/bad_sords to ids and append the ids to stop_token_ids/bad_token_ids.""" @@ -376,6 +379,7 @@ class PytorchEngineConfig: hf_overrides: 
Optional[Dict[str, Any]] = None disable_vision_encoder: bool = False logprobs_mode: str = None + enable_return_routed_experts: bool = False # dllm dllm_block_length: int = None @@ -457,14 +461,18 @@ class Response: logits: torch.Tensor = None last_hidden_state: torch.Tensor = None index: int = 0 + routed_experts: Any = None def __repr__(self): logits = 'logits=None' if self.logits is None else f'logits.shape={self.logits.shape}\nlogits={self.logits}' hidden_state = ( 'last_hidden_state=None' if self.last_hidden_state is None else f'last_hidden_state.shape={self.last_hidden_state.shape}\nlast_hidden_state={self.last_hidden_state}') - s = (f'text={self.text}\ngenerate_token_len={self.generate_token_len}\nfinish_reason="{self.finish_reason}"\n' - f'token_ids={self.token_ids}\nlog_probs={self.logprobs}\n{logits}\n{hidden_state}') + routed_experts = 'routed_experts=None' if self.routed_experts is None else \ + f'routed_experts.shape={self.routed_experts.shape}' + + s = (f'text={self.text!r}\ngenerate_token_len={self.generate_token_len}\nfinish_reason="{self.finish_reason}"\n' + f'token_ids={self.token_ids}\nlog_probs={self.logprobs}\n{logits}\n{hidden_state}\n{routed_experts}') return s @@ -544,6 +552,7 @@ class EngineOutput: last_hidden_state: torch.Tensor = None cache_block_ids: Optional[List[int]] = None req_metrics: Optional[RequestMetrics] = None + routed_experts: torch.Tensor = None @dataclass diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index 88a4f07098..1288c38ec9 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -91,16 +91,6 @@ def __init__( self.pool = pool self._graph: torch.cuda.CUDAGraph = None - def make_output_buffers(self, output): - """Make output buffers.""" - output_buffers = dict(logits=output) - return output_buffers - - def slice_output(self, output_buffers: Dict[str, Any], inputs: Dict[str, Any]): - """Slice output.""" - num_tokens = inputs['input_ids'].size(-1) - return output_buffers['logits'][:, :num_tokens] - @record_function('capture_cudagraph') def capture(self, **kwargs): """Capture graph.""" @@ -113,7 +103,7 @@ def capture(self, **kwargs): # warmup warmup_output = self.model(**padded_kwargs) - warmup_buffers = self.make_output_buffers(warmup_output) + warmup_buffers = self.model.make_output_buffers(warmup_output) self._graph = torch.cuda.CUDAGraph() # unsafe kernel call in other thread might invalid the capture @@ -121,9 +111,9 @@ def capture(self, **kwargs): with torch.cuda.graph(self._graph, pool=self.pool, stream=current_stream, capture_error_mode='thread_local'): output = self.model(**padded_kwargs) - output_buffers = self.make_output_buffers(output) + output_buffers = self.model.make_output_buffers(output) self.meta.output_buffers = output_buffers - output = self.slice_output(warmup_buffers, kwargs) + output = self.model.get_outputs_cudagraph(warmup_buffers, **kwargs) return output @record_function('forward_cudagraph') @@ -134,9 +124,8 @@ def forward(self, **kwargs): context = self.ctx_mgr.current_context() self.model.update_context_cudagraph(self.meta, context) self._graph.replay() - output_buffers = self.meta.output_buffers - output = self.slice_output(output_buffers, kwargs) + output = self.model.get_outputs_cudagraph(output_buffers, **kwargs) return output def __del__(self): diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index b21a9984ba..63bc36e96a 100644 --- a/lmdeploy/pytorch/config.py +++ 
b/lmdeploy/pytorch/config.py @@ -347,6 +347,7 @@ class MiscConfig: disable_vision_encoder: bool = False logprobs_mode: str = None dllm_config: DLLMConfig = None + enable_return_routed_experts: bool = False @classmethod def from_engine_config(cls, engine_config: PytorchEngineConfig): @@ -356,12 +357,15 @@ def from_engine_config(cls, engine_config: PytorchEngineConfig): unmasking_strategy=dllm_unmasking_strategy, denoising_steps=engine_config.dllm_denoising_steps, confidence_threshold=engine_config.dllm_confidence_threshold) - misc_config = cls(custom_module_map=engine_config.custom_module_map, - empty_init=engine_config.empty_init, - prefill_interval=engine_config.prefill_interval, - model_format=engine_config.model_format, - hf_overrides=engine_config.hf_overrides, - disable_vision_encoder=engine_config.disable_vision_encoder, - logprobs_mode=engine_config.logprobs_mode, - dllm_config=dllm_config) + misc_config = cls( + custom_module_map=engine_config.custom_module_map, + empty_init=engine_config.empty_init, + prefill_interval=engine_config.prefill_interval, + model_format=engine_config.model_format, + hf_overrides=engine_config.hf_overrides, + disable_vision_encoder=engine_config.disable_vision_encoder, + logprobs_mode=engine_config.logprobs_mode, + dllm_config=dllm_config, + enable_return_routed_experts=engine_config.enable_return_routed_experts, + ) return misc_config diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index f583ed7224..0531bb9650 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -57,6 +57,9 @@ class InferOutput: # for logging req_metrics: RequestMetrics = None + # expert ids + routed_experts: torch.Tensor = None + def _tensorlize_block_offsets(block_offsets, dtype=torch.int32): """Tensorlize block_offsets.""" @@ -876,13 +879,17 @@ def _make_infer_outputs( cur_logprobs = (logprobs.vals[idx][:num_logprobs + 1], logprobs.indices[idx][:num_logprobs + 1]) req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events) + routed_experts = msg.routed_experts if msg.return_routed_experts and finish else None + if routed_experts is not None: + routed_experts = self.executor.serialize(routed_experts) out = InferOutput(session_id=session_id, resp=msg.resp, finish=finish, token_ids=token_ids, cache_block_ids=cache_block_ids, req_metrics=req_metrics, - logprobs=cur_logprobs) + logprobs=cur_logprobs, + routed_experts=routed_experts) outputs[session_id] = out if msg.return_logits: @@ -896,6 +903,10 @@ def __need_logits(seqs: SeqList): """Need logits.""" return any(seq.return_logits for seq in seqs) + def __need_routed_experts(seqs: SeqList): + """Need routed experts.""" + return any(seq.return_routed_experts for seq in seqs) + def __need_schedule_again(prefill: bool, scheduler_output): """Need schedule again.""" # only reschedule when prefill @@ -939,6 +950,7 @@ def __need_schedule_again(prefill: bool, scheduler_output): inputs = self.create_model_inputs(running, prefill) sampling_inputs = self.sampling_strategy.make_sampling_inputs(running) return_logits = __need_logits(running) + return_routed_experts = __need_routed_experts(running) extra_inputs = self.model_agent_strategy.make_extra_inputs(running) stopping_criteria = self.model_agent_strategy.make_stopping_criteria(running) @@ -956,6 +968,7 @@ def __need_schedule_again(prefill: bool, scheduler_output): is_dummy=False, sync_long_context=sync_long_context, extra_inputs=extra_inputs, + return_routed_experts=return_routed_experts, ) async def 
_await_forward_event(self, forward_event: asyncio.Event): @@ -991,6 +1004,7 @@ def __send_resp(out: InferOutput): logits=out.logits, cache_block_ids=out.cache_block_ids, req_metrics=out.req_metrics, + routed_experts=out.routed_experts, logprobs=logprobs)) def __update_logprobs(step_outputs: List[InferOutput]): diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py index 66cc2e6883..7d71a83ee2 100644 --- a/lmdeploy/pytorch/engine/engine_instance.py +++ b/lmdeploy/pytorch/engine/engine_instance.py @@ -152,6 +152,8 @@ async def async_stream_infer(self, cache_block_ids = resp.data.get('cache_block_ids', None) if resp.data else None req_metrics = resp.data.get('req_metrics', None) if resp.data else None logprobs = resp.data.pop('logprobs', None) if resp.data else None + routed_experts = resp.data.get('routed_experts', None) if resp.data else None + if resp.type == ResponseType.SUCCESS: token_ids = resp.data['token_ids'].tolist() num_ids = len(token_ids) - output_offset @@ -160,6 +162,7 @@ async def async_stream_infer(self, token_ids[output_offset:], cache_block_ids=cache_block_ids, req_metrics=req_metrics, + routed_experts=routed_experts, logprobs=logprobs) output_offset = len(token_ids) elif resp.type == ResponseType.FINISH: @@ -173,6 +176,7 @@ async def async_stream_infer(self, logits=logits, cache_block_ids=cache_block_ids, req_metrics=req_metrics, + routed_experts=routed_experts, logprobs=logprobs) break else: diff --git a/lmdeploy/pytorch/engine/executor/base.py b/lmdeploy/pytorch/engine/executor/base.py index ba8333520b..82792f4927 100644 --- a/lmdeploy/pytorch/engine/executor/base.py +++ b/lmdeploy/pytorch/engine/executor/base.py @@ -101,6 +101,10 @@ def release(self): """Release resources.""" raise NotImplementedError('Not Implemented.') + def serialize(self, obj): + """Serialize obj.""" + return obj + async def forward_async(self, inputs): """Start forward.""" raise NotImplementedError('Not Implemented') diff --git a/lmdeploy/pytorch/engine/executor/ray_executor.py b/lmdeploy/pytorch/engine/executor/ray_executor.py index 327d56a5ca..1e473dd6a5 100644 --- a/lmdeploy/pytorch/engine/executor/ray_executor.py +++ b/lmdeploy/pytorch/engine/executor/ray_executor.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import asyncio +import base64 import contextlib import json import os @@ -351,6 +352,13 @@ def wakeup(self, tags: Optional[List[str]] = None): self.update_configs() self.collective_rpc('wakeup', (tags, )) + def serialize(self, obj) -> str: + """Serialize obj.""" + ref = ray.put(obj) + data = ray.cloudpickle.dumps(ref) + data = base64.b64encode(data).decode('utf-8') + return data + def get_input_processor(self): """Build cache engine.""" return ray.get(self.workers[0].get_input_processor.remote()) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 78368bfed9..5e47bc9464 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -72,6 +72,7 @@ class BatchedOutputs: logprobs: Optional[BatchedLogProbs] = None new_token_timestamp: int = 0 extra_outputs: Optional[ExtraOutputs] = None + all_routed_experts: Optional[torch.Tensor] = None def to_cpu(self): """To cpu.""" @@ -238,7 +239,6 @@ def model_forward( kv_quant_policy=cache_engine.cache_config.quant_policy, ) with ctx_mgr.context(context): - model_metas = None model_metas = model.update_model_metas( past_key_values=cache_engine.gpu_cache, context=context, @@ -248,13 +248,14 @@ def model_forward( context=context, ) output = model(**input_dict) - + if not isinstance(output, Dict): + output = dict(hidden_states=output) # InternVL-3.5-Flash will change the seqlen, model_metas during forward if context.model_metas is not None and context.model_metas[0] is not None: model_metas = context.model_metas - seq_length = context.q_seqlens[:len(inputs.seq_length)] - - return dict(hidden_states=output, model_metas=model_metas, seq_length=seq_length) + output['model_metas'] = model_metas + output['seq_length'] = context.q_seqlens[:len(inputs.seq_length)] + return output def _try_to_cuda(val, non_blocking: bool = False): @@ -444,6 +445,7 @@ async def _async_model_forward( inputs: ModelInputs, return_logits: bool, sync_long_context: bool, + return_routed_experts: bool, ): """Model forward.""" max_prefill_token_num = self.cache_config.max_prefill_token_num @@ -457,10 +459,23 @@ def __init__(self, max_seq_len): self._start = 0 self._output: torch.Tensor = None self._device: torch.device = None + self._routed_experts: torch.Tensor = None def gather(self, output): """gather.""" tmp_output = output['hidden_states'] + seq_len = tmp_output.size(-2) + + if return_routed_experts and 'all_routed_experts' in output: + tmp_exp_ids = output['all_routed_experts'] + out_exp_ids = self._routed_experts + if out_exp_ids is None: + out_exp_ids = tmp_exp_ids.new_empty(self._max_seq_len, *tmp_exp_ids.shape[1:], device='cpu') + self._device = tmp_output.device + out_exp_ids[self._start:self._start + seq_len, ...].copy_(tmp_exp_ids, non_blocking=True) + self._routed_experts = out_exp_ids + if not return_logits: + self._start += seq_len if not return_logits: self._output = tmp_output @@ -468,7 +483,7 @@ def gather(self, output): out_logits = self._output start = self._start - seq_len = tmp_output.size(-2) + if out_logits is None: out_logits = tmp_output.new_empty(1, self._max_seq_len, tmp_output.size(-1), device='cpu') self._device = tmp_output.device @@ -478,14 +493,23 @@ def gather(self, output): def get_output(self): """Get tmp_output.""" - if not return_logits: + if not (return_logits or return_routed_experts): seqlen = torch.full((1, ), self._output.numel() // self._output.size(-1), device=self._output.device, dtype=self._output.dtype) - return strategy.slice_outputs(self._output, seqlen) - 
torch.cuda.synchronize() - return self._output.to(self._device) + return strategy.slice_outputs(self._output, seqlen), None + else: + if return_logits: + torch.cuda.synchronize() + output_hidden_states = self._output.to(self._device) + else: + seqlen = torch.full((1, ), + self._output.numel() // self._output.size(-1), + device=self._output.device, + dtype=self._output.dtype) + output_hidden_states = strategy.slice_outputs(self._output, seqlen) + return output_hidden_states, self._routed_experts __forward = self.async_forward @@ -503,7 +527,11 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int): model_metas = tmp_out.get('model_metas') output_gather.gather(tmp_out) tmp_out.pop('hidden_states', None) - tmp_out['hidden_states'] = output_gather.get_output() + tmp_out.pop('all_routed_experts', None) + tmp_out['hidden_states'], routed_experts = output_gather.get_output() + + if return_routed_experts: + tmp_out['all_routed_experts'] = routed_experts return tmp_out origin_inputs = inputs @@ -595,6 +623,7 @@ async def _async_step_background( sampling_inputs: SamplingInputs = None, stopping_criteria: StoppingCriteria = None, return_logits: bool = False, + return_routed_experts: bool = False, is_dummy: bool = False, sync_long_context: bool = False, extra_inputs: ExtraInputs = None, @@ -707,6 +736,7 @@ async def __prepare_dp(): inputs, return_logits=return_logits, sync_long_context=sync_long_context, + return_routed_experts=return_routed_experts and need_output, ) logits = output['logits'] logits = logits[0] # [bs, seq, prob] -> [seq, prob] @@ -731,6 +761,8 @@ async def __prepare_dp(): # post sampling next_token_ids, extra_inputs = self.agent_strategy.post_sampling(inputs, last_logits, next_token_ids, extra_inputs) + # for router replay + all_routed_experts = output.get('all_routed_experts', None) with self._broadcast_next_token(next_token_ids, extra_inputs, enable=need_broadcast_next): logger.debug(f' rank[{rank}]: synchronize token ids [{idx}]') @@ -751,6 +783,7 @@ async def __prepare_dp(): stop_pos=stop_pos, model_metas=model_metas, logprobs=logprobs, + all_routed_experts=all_routed_experts, extra_outputs=extra_outputs)) else: # Avoid adding the ADInplaceOrView dispatch key to `next_token_ids`, @@ -934,9 +967,16 @@ def _build_model(self): if custom_module_map is not None: update_custom_module_map(custom_module_map) logger.debug(msg_with_rank(rank, 'build model.')) - build_model_ctx = BuildModelContext(disable_vision_encoder=self.misc_config.disable_vision_encoder, - dllm_config=self.misc_config.dllm_config, - strategy_factory=self.strategy_factory) + # for router replay + need_output = self.dist_ctx.dp > 1 or self.dist_ctx.rank % self.dist_ctx.tp == 0 + enable_return_routed_experts = self.misc_config.enable_return_routed_experts and need_output + + build_model_ctx = BuildModelContext( + disable_vision_encoder=self.misc_config.disable_vision_encoder, + dllm_config=self.misc_config.dllm_config, + strategy_factory=self.strategy_factory, + enable_return_routed_experts=enable_return_routed_experts, + ) patched_model = build_patched_model(self.model_config, device=device, model_format=self.misc_config.model_format, diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index ec5e6098b8..6ebac7ad5c 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -56,6 +56,7 @@ class SamplingParam: out_logits: bool = False out_last_hidden_states: bool = False num_logprobs: int = -1 + return_routed_experts: bool = False @classmethod def 
from_gen_config(cls, gen_config: GenerationConfig): @@ -121,21 +122,24 @@ def from_gen_config(cls, gen_config: GenerationConfig): if random_seed is None: import random random_seed = random.getrandbits(64) - return SamplingParam(top_p=top_p, - top_k=top_k, - min_p=min_p, - temperature=temperature, - repetition_penalty=repetition_penalty, - ignore_eos=gen_config.ignore_eos, - random_seed=random_seed, - stop_words=stop_words, - bad_words=bad_words, - response_format=response_format, - max_new_tokens=max_new_tokens, - min_new_tokens=min_new_tokens, - logits_processors=gen_config.logits_processors, - out_logits=(output_logits is not None), - num_logprobs=logprobs) + return SamplingParam( + top_p=top_p, + top_k=top_k, + min_p=min_p, + temperature=temperature, + repetition_penalty=repetition_penalty, + ignore_eos=gen_config.ignore_eos, + random_seed=random_seed, + stop_words=stop_words, + bad_words=bad_words, + response_format=response_format, + max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, + logits_processors=gen_config.logits_processors, + out_logits=(output_logits is not None), + num_logprobs=logprobs, + return_routed_experts=gen_config.return_routed_experts, + ) class MessageStatus(enum.Enum): @@ -407,6 +411,69 @@ def copy(self): return self.clone() +class HistoryRouterExperts: + """History router experts.""" + ALLOC_SIZE = 64 + + def __init__(self, expert_ids: np.ndarray = None, dtype: np.dtype = np.int16): + self.dtype = dtype + if expert_ids is None: + self._expert_ids = None + self._num_real = 0 + else: + self._expert_ids = expert_ids.astype(dtype) + self._num_real = len(expert_ids) + + def reserve(self, size: int): + """Reserve cache.""" + if self._expert_ids is None: + return + num_tokens = len(self._expert_ids) + if num_tokens >= size: + return + reserve_size = _round_up(size - num_tokens, self.ALLOC_SIZE) + new_expert_ids = np.pad(self._expert_ids, ((0, reserve_size), (0, 0), (0, 0))) + self._expert_ids = new_expert_ids + + def get_real(self): + """Get real data.""" + if self._expert_ids is None: + return None + return self._expert_ids[:self._num_real] + + def resize(self, size: int): + """Set size.""" + assert size <= self._num_real + self._num_real = size + + def append(self, expert_ids: np.ndarray): + """Append token ids.""" + if self._expert_ids is None: + self._expert_ids = expert_ids.astype(self.dtype) + self._num_real = len(expert_ids) + return + num_tokens = len(expert_ids) + self.reserve(num_tokens + self._num_real) + slice_start = self._num_real + slice_end = slice_start + num_tokens + self._num_real += num_tokens + self._expert_ids[slice_start:slice_end] = expert_ids + + def __len__(self): + """Get length.""" + return self._num_real + + def clone(self): + """clone.""" + expert_ids = None if self._expert_ids is None else self.get_real().copy() + ret = HistoryRouterExperts(expert_ids=expert_ids, dtype=self.dtype) + return ret + + def copy(self): + """copy.""" + return self.clone() + + class HistoryMultiModals: def __init__(self, multimodals: MultiModalInputs = None): @@ -500,6 +567,9 @@ class SchedulerSequence: # For logging engine_events: List[EngineEvent] = field(default_factory=list) + # for router replay + all_routed_experts: HistoryRouterExperts = field(default_factory=HistoryRouterExperts) + def __post_init__(self): """Post init.""" self._seq_meta: SequenceMeta = self.session.seq_meta @@ -567,6 +637,21 @@ def generated_ids(self) -> np.ndarray: start = end - self.num_new_tokens return self.history_cache._token_ids[start:end] + @property + def 
return_routed_experts(self) -> bool: + return self.sampling_param.return_routed_experts + + @property + def routed_experts(self) -> np.ndarray: + if (not self.return_routed_experts) or self.all_routed_experts is None: + return None + + end = max(0, self.num_all_ids - 1) + if 0 < end <= len(self.all_routed_experts): + return self.all_routed_experts.get_real()[:end] + else: + return None + @property def num_history_ids(self): """Num history ids.""" diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 7b3a09de3d..c6271d993b 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -448,6 +448,7 @@ class BuildModelContext: disable_vision_encoder: bool = False dllm_config: DLLMConfig = None strategy_factory: 'StrategyFactoryBase' = None + enable_return_routed_experts: bool = False class StepContextManager: diff --git a/lmdeploy/pytorch/models/qwen3_moe.py b/lmdeploy/pytorch/models/qwen3_moe.py index d66ad10ebf..bdcc5b1d2a 100644 --- a/lmdeploy/pytorch/models/qwen3_moe.py +++ b/lmdeploy/pytorch/models/qwen3_moe.py @@ -14,7 +14,8 @@ from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight -from .utils.cudagraph import CudaGraphMixin +from .patch import get_build_model_context +from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin class Qwen3MoeAttention(nn.Module): @@ -220,12 +221,18 @@ def __init__(self, layer_idx=layer_idx, ) - def forward(self, hidden_states: torch.Tensor): + def forward( + self, + hidden_states: torch.Tensor, + all_routed_experts: torch.Tensor = None, + ): """forward.""" batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) router_logits = self.gate(hidden_states) topk_weights, topk_ids = self.softmax_topk(router_logits) + if all_routed_experts is not None: + all_routed_experts[:, self.layer_idx, :] = topk_ids if get_dist_manager().current_context().dist_config.enable_eplb: topk_ids = EPLBManager.topk_ids_logical_to_physical(topk_ids, self.eplb_dispatch_info) out_states = self.experts( @@ -277,6 +284,7 @@ def forward( past_key_value: Optional[List[torch.FloatTensor]], residual: Optional[torch.Tensor] = None, attn_metadata: Any = None, + all_routed_experts: torch.Tensor = None, ): if residual is None: @@ -295,7 +303,7 @@ def forward( # Fully Connected hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) - hidden_states = self.mlp(hidden_states) + hidden_states = self.mlp(hidden_states, all_routed_experts=all_routed_experts) outputs = (hidden_states, residual) return outputs @@ -346,6 +354,7 @@ def forward( past_key_values: Optional[List[torch.FloatTensor]] = None, attn_metadata: Any = None, inputs_embeds: Optional[torch.FloatTensor] = None, + all_routed_experts: torch.Tensor = None, ): """Rewrite of LlamaModel.forward.""" @@ -370,6 +379,7 @@ def forward( past_key_value=past_key_value, residual=residual, attn_metadata=attn_metadata, + all_routed_experts=all_routed_experts, ) # norm @@ -405,6 +415,7 @@ def __init__(self, super().__init__() self.config = config self.ctx_mgr = ctx_mgr + # build model self.model = Qwen3MoeModel(config, dtype=dtype, device=device) # build lm_head @@ -413,6 +424,9 @@ def __init__(self, bias=False, dtype=dtype, device=device) + # for router replay + bm_ctx = get_build_model_context() + self.enable_return_routed_experts = bm_ctx.enable_return_routed_experts def forward( self, @@ -421,6 +435,7 @@ def forward( 
past_key_values: List[List[torch.Tensor]], attn_metadata: Any = None, inputs_embeds: torch.Tensor = None, + all_routed_experts: torch.Tensor = None, **kwargs, ): """Model forward, return logits.""" @@ -430,8 +445,11 @@ def forward( past_key_values=past_key_values, attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, + all_routed_experts=all_routed_experts, ) - return hidden_states + if all_routed_experts is None: + return hidden_states + return dict(hidden_states=hidden_states, all_routed_experts=all_routed_experts) def get_logits(self, hidden_states: torch.Tensor): """Compute logits of the model output.""" @@ -441,6 +459,36 @@ def get_input_embeddings(self): """Get input embeddings.""" return self.model.get_input_embeddings() + def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids, **kwargs): + """Make cudagraph buffers from forward inputs.""" + max_tokens = graph_meta.max_tokens + + input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) + if self.enable_return_routed_experts: + input_buffers['all_routed_experts'] = input_ids.new_empty( + (max_tokens, self.config.num_hidden_layers, self.config.num_experts_per_tok), dtype=torch.int16) + + return input_buffers + + def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): + """Fill cudagraph buffers from forward inputs.""" + + new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs) + + input_buffers = graph_meta.input_buffers + if self.enable_return_routed_experts: + new_inputs['all_routed_experts'] = input_buffers['all_routed_experts'] + return new_inputs + + def get_outputs_cudagraph(self, output_buffers: Dict[str, torch.Tensor], input_ids: torch.Tensor, **kwargs): + """Get outputs from buffers.""" + num_tokens = input_ids.size(-1) + outputs = dict() + outputs['hidden_states'] = output_buffers['hidden_states'][:, :num_tokens] + if self.enable_return_routed_experts: + outputs['all_routed_experts'] = output_buffers['all_routed_experts'][:num_tokens, ...].clone() + return outputs + def prepare_inputs_for_generation( self, past_key_values: List[List[torch.Tensor]], @@ -461,6 +509,11 @@ def prepare_inputs_for_generation( inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds) + # expert ids + all_routed_experts = None + if self.enable_return_routed_experts: + all_routed_experts = input_ids.new_empty( + (input_ids.size(1), self.config.num_hidden_layers, self.config.num_experts_per_tok), dtype=torch.int16) # inputs of forward return dict( input_ids=input_ids, @@ -468,6 +521,7 @@ def prepare_inputs_for_generation( past_key_values=past_key_values, attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, + all_routed_experts=all_routed_experts, ) def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter], diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py index 5137f8f132..1416c00e8c 100644 --- a/lmdeploy/pytorch/models/utils/cudagraph.py +++ b/lmdeploy/pytorch/models/utils/cudagraph.py @@ -51,6 +51,15 @@ def support_cuda_graph( """Return True is model support cudagraph.""" return attn_metadata.is_decoding + def make_output_buffers(self, output): + """Make output buffers.""" + if isinstance(output, torch.Tensor): + output_buffers = dict(hidden_states=output) + else: + assert isinstance(output, Dict) + output_buffers = output + return output_buffers + def make_buffers_cudagraph(self, 
graph_meta: CudaGraphMeta, *args, **kwargs) -> BuffType: """Make cudagraph buffers from forward inputs.""" max_batches = graph_meta.max_batchs @@ -179,3 +188,10 @@ def update_context_cudagraph(self, graph_meta: CudaGraphMeta, context: StepConte context.q_seqlens = input_buffers['q_seqlens'] context.kv_seqlens = input_buffers['kv_seqlens'] context.q_start_loc = input_buffers['q_start_loc'] + + def get_outputs_cudagraph(self, output_buffers: Dict[str, torch.Tensor], input_ids: Tensor, **kwargs): + """Get outputs from buffers.""" + num_tokens = input_ids.size(-1) + outputs = dict() + outputs['hidden_states'] = output_buffers['hidden_states'][:, :num_tokens] + return outputs diff --git a/lmdeploy/pytorch/strategies/ar/model_agent.py b/lmdeploy/pytorch/strategies/ar/model_agent.py index 94429dae3f..7cf689c950 100644 --- a/lmdeploy/pytorch/strategies/ar/model_agent.py +++ b/lmdeploy/pytorch/strategies/ar/model_agent.py @@ -17,10 +17,12 @@ SeqList = List[SchedulerSequence] +@dataclass class ARExtraInputs(ExtraInputs): """Ar extra inputs.""" +@dataclass class ARExtraOutputs(ExtraOutputs): """Ar extra outputs.""" diff --git a/lmdeploy/pytorch/strategies/ar/sequence.py b/lmdeploy/pytorch/strategies/ar/sequence.py index 91a3335f18..197217c8bb 100644 --- a/lmdeploy/pytorch/strategies/ar/sequence.py +++ b/lmdeploy/pytorch/strategies/ar/sequence.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional +import numpy as np from torch import Tensor from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest @@ -24,6 +25,7 @@ def update_token_ids(self, embeddings: List[InputEmbeddings] = None, model_meta: Dict[str, Any] = None, mode: UpdateTokenMode = UpdateTokenMode.INPUTS, + routed_experts: np.ndarray = None, **kwargs): """Update token ids, old token ids will be added to history.""" # update history image nums @@ -35,6 +37,10 @@ def update_token_ids(self, token_ids = _to_ndarray(token_ids) num_valid = len(token_ids) + # record cached expert ids + if self.return_routed_experts: + if routed_experts is not None: + self.all_routed_experts.append(routed_experts) if mode == UpdateTokenMode.INPUTS: self.arrive_time = time.perf_counter() @@ -72,6 +78,9 @@ def set_step(self, step: int): self._num_history_cross = self.history_multimodals.get_encoder_len(0, self.num_history_ids) self._num_cross = self.history_multimodals.get_encoder_len(self._num_history_ids, num_all_ids) + if self.return_routed_experts: + self.all_routed_experts.resize(step) + class ARSequenceStrategy(SequenceStrategy): @@ -84,13 +93,15 @@ def make_sequence(self, resp_cache: bool = False, preserve_cache: bool = False) -> 'SchedulerSequence': """Make sequence.""" - return SchedulerSequenceDefault(seq_id=seq_id, - session=session, - sampling_param=sampling_param, - adapter_name=adapter_name, - migration_request=migration_request, - resp_cache=resp_cache, - preserve_cache=preserve_cache) + return SchedulerSequenceDefault( + seq_id=seq_id, + session=session, + sampling_param=sampling_param, + adapter_name=adapter_name, + migration_request=migration_request, + resp_cache=resp_cache, + preserve_cache=preserve_cache, + ) def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_decoding: bool) -> None: """Update running sequences.""" @@ -102,12 +113,22 @@ def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_d model_metas = [None] * len(running) next_token_ids = next_token_ids.numpy() + all_routed_experts = [None] * len(running) + if is_decoding: + num_tokens = 
[1] * len(running) + else: + num_tokens = [msg.num_token_ids for msg in running] + + if batched_outputs.all_routed_experts is not None: + all_routed_experts = batched_outputs.all_routed_experts.split(num_tokens, dim=0) + all_routed_experts = [experts.numpy() for experts in all_routed_experts] update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL - for token, msg, stop, model_meta in zip(next_token_ids, running, stopped, model_metas): + for token, msg, stop, model_meta, routed_experts in zip(next_token_ids, running, stopped, model_metas, + all_routed_experts): if msg.status != MessageStatus.LOCKED: continue # fill token - msg.update_token_ids(token, model_meta=model_meta, mode=update_mode) + msg.update_token_ids(token, model_meta=model_meta, mode=update_mode, routed_experts=routed_experts) if stop: msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 792e1992e3..211c2fdd13 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -98,6 +98,8 @@ class GenOut: # for disaggregation cache_block_ids: List[int] = None + routed_experts: Any = None + def _gen_out_to_response(out: GenOut, index) -> Response: return Response(text=out.response, @@ -108,6 +110,7 @@ def _gen_out_to_response(out: GenOut, index) -> Response: logprobs=out.logprobs, last_hidden_state=out.last_hidden_state, logits=out.logits, + routed_experts=out.routed_experts, index=index) @@ -125,6 +128,7 @@ def _append_response(dst: Response, src: Response): if src.logprobs: dst.logprobs = dst.logprobs or [] dst.logprobs += src.logprobs + dst.routed_experts = src.routed_experts return dst @@ -903,6 +907,7 @@ def is_error(status): gen_len, finish_reason, token_ids=res, + routed_experts=outputs.routed_experts, cache_block_ids=outputs.cache_block_ids) if outputs.logprobs is not None: out.logprobs = (outputs.logprobs[:-hit_stop_token] if hit_stop_token else outputs.logprobs) @@ -933,6 +938,13 @@ def is_error(status): logits = outputs.logits[-1:] if outputs.logits else None last_hidden_state = outputs.last_hidden_state[-1:] if outputs.last_hidden_state else None logprobs = outputs.logprobs[-1:] if outputs.logprobs else None + gen_len += 1 + + # router replay + routed_experts = outputs.routed_experts + if routed_experts is not None and not isinstance(routed_experts, str) and ( + not gen_config.include_stop_str_in_output) and finish_reason == 'stop': + routed_experts = routed_experts[:-1] logger.info(f'session {session_id} finished, reason ' f'"{finish_reason}", input_tokens ' @@ -946,6 +958,7 @@ def is_error(status): logprobs=logprobs, logits=logits, last_hidden_state=last_hidden_state, + routed_experts=routed_experts, cache_block_ids=outputs.cache_block_ids) # Update a session's sequence only when it is in finished status if outputs.status == ResponseType.FINISH: diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 4403b185d4..3458e6ace5 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -942,6 +942,7 @@ async def generate(request: GenerateReqInput, raw_request: Request = None): skip_special_tokens=request.skip_special_tokens, spaces_between_special_tokens=request.spaces_between_special_tokens, include_stop_str_in_output=request.include_stop_str_in_output, + return_routed_experts=request.return_routed_experts, ) result_generator = VariableInterface.async_engine.generate( @@ -955,23 +956,33 
@@ async def generate(request: GenerateReqInput, raw_request: Request = None): do_preprocess=False, ) - def create_generate_response_json(res, text, output_ids, logprobs, finish_reason): + def create_generate_response_json(res, text, output_ids, logprobs, finish_reason, routed_experts=None): + # only output router experts in last chunk + routed_experts = None if finish_reason is None else routed_experts meta = GenerateReqMetaOutput(finish_reason=dict(type=finish_reason) if finish_reason else None, output_token_logprobs=logprobs or None, prompt_tokens=res.input_token_len, + routed_experts=routed_experts, completion_tokens=res.generate_token_len) - response = GenerateReqOutput(text=text, output_ids=output_ids, meta_info=meta) + + response = GenerateReqOutput(text=text, output_ids=output_ids, meta_info=meta, routed_experts=routed_experts) return response.model_dump_json() async def generate_stream_generator(): async for res in result_generator: text = res.response or '' output_ids = res.token_ids + routed_experts = res.routed_experts logprobs = [] if res.logprobs: for tok, tok_logprobs in zip(res.token_ids, res.logprobs): logprobs.append((tok_logprobs[tok], tok)) - response_json = create_generate_response_json(res, text, output_ids, logprobs, res.finish_reason) + response_json = create_generate_response_json(res, + text, + output_ids, + logprobs, + res.finish_reason, + routed_experts=routed_experts) yield f'data: {response_json}\n\n' yield 'data: [DONE]\n\n' @@ -998,6 +1009,7 @@ async def _inner_call(): meta = GenerateReqMetaOutput(finish_reason=dict(type=res.finish_reason) if res.finish_reason else None, output_token_logprobs=logprobs or None, prompt_tokens=res.input_token_len, + routed_experts=res.routed_experts, completion_tokens=res.generate_token_len) response = GenerateReqOutput(text=text, output_ids=output_ids, meta_info=meta) diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py index e2371f0673..36d1ce9783 100644 --- a/lmdeploy/serve/openai/protocol.py +++ b/lmdeploy/serve/openai/protocol.py @@ -464,6 +464,7 @@ class GenerateReqInput(BaseModel): skip_special_tokens: Optional[bool] = True spaces_between_special_tokens: Optional[bool] = True include_stop_str_in_output: Optional[bool] = False + return_routed_experts: Optional[bool] = False class GenerateReqMetaOutput(BaseModel): @@ -471,6 +472,7 @@ class GenerateReqMetaOutput(BaseModel): completion_tokens: Optional[int] = None finish_reason: Optional[Dict[str, Any]] = None output_token_logprobs: Optional[List[tuple[float, int]]] = None # (logprob, token_id) + routed_experts: Optional[Union[List[List[List[int]]], str]] = None # (num_token, num_layer, topk_expert) # /generate output
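
Usage note (not part of the diff): the feature needs both the engine-level flag `enable_return_routed_experts` on `PytorchEngineConfig` and the per-request flag `return_routed_experts` on `GenerationConfig`, mirroring what `benchmark/profile_pipeline_api.py` does above. A minimal pipeline sketch, assuming a MoE checkpoint that is wired for the feature (this diff wires up Qwen3-MoE); the model path is a placeholder:

```python
from lmdeploy import GenerationConfig, PytorchEngineConfig, pipeline

# Engine-level switch: lets the model allocate per-token expert-id buffers.
engine_config = PytorchEngineConfig(enable_return_routed_experts=True)
# 'Qwen/Qwen3-30B-A3B' is a placeholder; any supported Qwen3-MoE model works.
pipe = pipeline('Qwen/Qwen3-30B-A3B', backend_config=engine_config)

# Request-level switch: asks the engine to ship the recorded expert ids back.
gen_config = GenerationConfig(max_new_tokens=32, return_routed_experts=True)
resp = pipe(['Explain mixture-of-experts routing.'], gen_config=gen_config)[0]

experts = resp.routed_experts
if experts is not None and not isinstance(experts, str):
    # Per the protocol comment: (num_token, num_layer, topk_expert).
    print(experts.shape)
```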
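When the engine runs on the Ray executor, `RayExecutor.serialize` ships the expert ids as `base64(cloudpickle.dumps(ray.put(obj)))` and the client receives a string. The decode side is not part of this PR; a sketch of the inverse (an assumed helper), valid only while the client is attached to the same Ray cluster and the object is still alive:

```python
import base64

import ray


def deserialize_routed_experts(data: str):
    """Hypothetical helper: invert RayExecutor.serialize.

    serialize() does ray.put -> ray.cloudpickle.dumps -> base64, so undo it in
    reverse order. Requires ray.init(address='auto') (or equivalent) against
    the cluster that produced the payload.
    """
    ref = ray.cloudpickle.loads(base64.b64decode(data))
    return ray.get(ref)  # resolve the ObjectRef to the routed-experts tensor
```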
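On the serving side, start `lmdeploy serve api_server` with the PyTorch backend and the new `--enable-return-routed-experts` flag; `/generate` requests can then set `return_routed_experts` and read `meta_info.routed_experts` (emitted only in the last chunk when streaming). A rough client sketch; apart from `return_routed_experts` and `routed_experts`, the payload field names and port are assumptions to be checked against `GenerateReqInput`:

```python
import requests

# Assumed server: lmdeploy serve api_server <moe-model> --backend pytorch \
#                     --enable-return-routed-experts
payload = {
    'text': 'Explain mixture-of-experts routing.',  # prompt field name assumed
    'return_routed_experts': True,                  # added by this PR
    'stream': False,
}
resp = requests.post('http://0.0.0.0:23333/generate', json=payload).json()

# GenerateReqMetaOutput.routed_experts: (num_token, num_layer, topk_expert)
# nested lists, or a serialized string when the Ray executor is in use.
print(resp['meta_info']['routed_experts'])
```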
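For reference, a small behaviour sketch of the new per-sequence `HistoryRouterExperts` cache added in `lmdeploy/pytorch/messages.py` (values are illustrative):

```python
import numpy as np

from lmdeploy.pytorch.messages import HistoryRouterExperts

num_layers, topk = 4, 2
cache = HistoryRouterExperts()  # starts empty, int16 storage

prefill_ids = np.random.randint(0, 8, (5, num_layers, topk))  # 5 prompt tokens
cache.append(prefill_ids)       # first append creates the backing array
decode_ids = np.random.randint(0, 8, (1, num_layers, topk))   # one decoded token
cache.append(decode_ids)        # later appends over-allocate in ALLOC_SIZE (64) steps

assert len(cache) == 6
assert cache.get_real().shape == (6, num_layers, topk)
cache.resize(5)                 # rolled back by SchedulerSequence.set_step()
assert len(cache) == 5
```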