5 changes: 4 additions & 1 deletion benchmark/profile_pipeline_api.py
@@ -134,7 +134,7 @@ class Engine:
def __init__(self, model_path: str, engine_config, csv: str):
self.pipe = pipeline(model_path, backend_config=engine_config, log_level='ERROR')
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

self.return_routed_experts = getattr(self.pipe.backend_config, 'enable_return_routed_experts', False)
self.csv = csv

def process_request(self, requests, profiler: Profiler, temperature, top_p, top_k, stream_output):
@@ -146,6 +146,7 @@ def process_request(self, requests, profiler: Profiler, temperature, top_p, top_
top_k=top_k,
ignore_eos=True,
do_sample=False,
return_routed_experts=self.return_routed_experts,
max_new_tokens=output_len) for _, _, output_len in requests
]

@@ -254,6 +255,7 @@ def parse_args():
# pytorch engine args
pt_group = parser.add_argument_group('PyTorch engine arguments')
ArgumentHelper.eager_mode(pt_group)
ArgumentHelper.enable_return_routed_experts(pt_group)

tp_act = ArgumentHelper.tp(pt_group)
cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)
@@ -302,6 +304,7 @@ def main():
thread_safe=False,
eager_mode=args.eager_mode,
enable_prefix_caching=args.enable_prefix_caching,
enable_return_routed_experts=args.enable_return_routed_experts,
)

engine = Engine(args.model_path, engine_config, csv=args.csv)
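
For context, a minimal usage sketch of the two knobs wired up above: enable_return_routed_experts on the engine config and return_routed_experts per request. The model path is a placeholder, and the layout of routed_experts depends on the model:

from lmdeploy import GenerationConfig, PytorchEngineConfig, pipeline

# Placeholder model path; any MoE model supported by the PyTorch engine.
model_path = 'Qwen/Qwen2-57B-A14B-Instruct'

engine_config = PytorchEngineConfig(enable_return_routed_experts=True)
pipe = pipeline(model_path, backend_config=engine_config)

gen_config = GenerationConfig(max_new_tokens=16, return_routed_experts=True)
resp = pipe(['Hello, world'], gen_config=gen_config)[0]
# resp.routed_experts holds the per-token expert ids; it stays None when the feature is off.
print(resp.routed_experts)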
2 changes: 2 additions & 0 deletions lmdeploy/cli/serve.py
@@ -97,6 +97,7 @@ def add_parser_api_server():
ArgumentHelper.dllm_unmasking_strategy(pt_group)
ArgumentHelper.dllm_denoising_steps(pt_group)
ArgumentHelper.dllm_confidence_threshold(pt_group)
ArgumentHelper.enable_return_routed_experts(pt_group)

# common engine args
dtype_act = ArgumentHelper.dtype(pt_group)
@@ -228,6 +229,7 @@ def api_server(args):
dllm_unmasking_strategy=args.dllm_unmasking_strategy,
dllm_denoising_steps=args.dllm_denoising_steps,
dllm_confidence_threshold=args.dllm_confidence_threshold,
enable_return_routed_experts=args.enable_return_routed_experts,
)
else:
from lmdeploy.messages import TurbomindEngineConfig
9 changes: 9 additions & 0 deletions lmdeploy/cli/utils.py
@@ -667,6 +667,15 @@ def dllm_confidence_threshold(parser):
default=0.85,
help='The confidence threshold for dllm.')

@staticmethod
def enable_return_routed_experts(parser):
"""Add argument return routed experts to parser."""

return parser.add_argument('--enable-return-routed-experts',
action='store_true',
default=False,
help='Whether to output routed expert ids for replay.')


# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py
class FlexibleArgumentParser(argparse.ArgumentParser):
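
A standalone sketch of what the helper above contributes to a parser; plain argparse, so the flag's behaviour can be checked in isolation:

import argparse

parser = argparse.ArgumentParser('demo')
pt_group = parser.add_argument_group('PyTorch engine arguments')
# Mirrors ArgumentHelper.enable_return_routed_experts(pt_group) from the diff above.
pt_group.add_argument('--enable-return-routed-experts',
                      action='store_true',
                      default=False,
                      help='Whether to output routed expert ids for replay.')

args = parser.parse_args(['--enable-return-routed-experts'])
assert args.enable_return_routed_experts is True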
13 changes: 11 additions & 2 deletions lmdeploy/messages.py
@@ -117,6 +117,9 @@ class GenerationConfig:
preserve_cache: bool = False
migration_request: Optional[MigrationRequest] = None

# router replay
return_routed_experts: bool = False

def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer):
"""Convert stop_words/bad_sords to ids and append the ids to
stop_token_ids/bad_token_ids."""
@@ -376,6 +379,7 @@ class PytorchEngineConfig:
hf_overrides: Optional[Dict[str, Any]] = None
disable_vision_encoder: bool = False
logprobs_mode: str = None
enable_return_routed_experts: bool = False

# dllm
dllm_block_length: int = None
@@ -457,14 +461,18 @@ class Response:
logits: torch.Tensor = None
last_hidden_state: torch.Tensor = None
index: int = 0
routed_experts: Any = None

def __repr__(self):
logits = 'logits=None' if self.logits is None else f'logits.shape={self.logits.shape}\nlogits={self.logits}'
hidden_state = (
'last_hidden_state=None' if self.last_hidden_state is None else
f'last_hidden_state.shape={self.last_hidden_state.shape}\nlast_hidden_state={self.last_hidden_state}')
s = (f'text={self.text}\ngenerate_token_len={self.generate_token_len}\nfinish_reason="{self.finish_reason}"\n'
f'token_ids={self.token_ids}\nlog_probs={self.logprobs}\n{logits}\n{hidden_state}')
routed_experts = 'routed_experts=None' if self.routed_experts is None else \
f'routed_experts.shape={self.routed_experts.shape}'

s = (f'text={self.text!r}\ngenerate_token_len={self.generate_token_len}\nfinish_reason="{self.finish_reason}"\n'
f'token_ids={self.token_ids}\nlog_probs={self.logprobs}\n{logits}\n{hidden_state}\n{routed_experts}')
return s


@@ -544,6 +552,7 @@ class EngineOutput:
last_hidden_state: torch.Tensor = None
cache_block_ids: Optional[List[int]] = None
req_metrics: Optional[RequestMetrics] = None
routed_experts: torch.Tensor = None


@dataclass
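
The repr change above reports only the tensor shape instead of dumping the expert ids. A standalone sketch of that formatting; the [tokens, layers, top_k] layout is an assumption, not something this diff specifies:

import torch

routed_experts = torch.randint(0, 64, (32, 48, 8))  # hypothetical [tokens, layers, top_k]
line = ('routed_experts=None' if routed_experts is None
        else f'routed_experts.shape={routed_experts.shape}')
print(line)  # routed_experts.shape=torch.Size([32, 48, 8])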
19 changes: 4 additions & 15 deletions lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -91,16 +91,6 @@ def __init__(
self.pool = pool
self._graph: torch.cuda.CUDAGraph = None

def make_output_buffers(self, output):
"""Make output buffers."""
output_buffers = dict(logits=output)
return output_buffers

def slice_output(self, output_buffers: Dict[str, Any], inputs: Dict[str, Any]):
"""Slice output."""
num_tokens = inputs['input_ids'].size(-1)
return output_buffers['logits'][:, :num_tokens]

@record_function('capture_cudagraph')
def capture(self, **kwargs):
"""Capture graph."""
@@ -113,17 +103,17 @@ def capture(self, **kwargs):

# warmup
warmup_output = self.model(**padded_kwargs)
warmup_buffers = self.make_output_buffers(warmup_output)
warmup_buffers = self.model.make_output_buffers(warmup_output)

self._graph = torch.cuda.CUDAGraph()
# an unsafe kernel call in another thread might invalidate the capture,
# so we set the thread_local capture error mode here.
with torch.cuda.graph(self._graph, pool=self.pool, stream=current_stream, capture_error_mode='thread_local'):
output = self.model(**padded_kwargs)

output_buffers = self.make_output_buffers(output)
output_buffers = self.model.make_output_buffers(output)
self.meta.output_buffers = output_buffers
output = self.slice_output(warmup_buffers, kwargs)
output = self.model.get_outputs_cudagraph(warmup_buffers, **kwargs)
return output

@record_function('forward_cudagraph')
@@ -134,9 +124,8 @@ def forward(self, **kwargs):
context = self.ctx_mgr.current_context()
self.model.update_context_cudagraph(self.meta, context)
self._graph.replay()

output_buffers = self.meta.output_buffers
output = self.slice_output(output_buffers, kwargs)
output = self.model.get_outputs_cudagraph(output_buffers, **kwargs)
return output

def __del__(self):
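
With make_output_buffers and slice_output removed from the runner, buffer creation and slicing are delegated to the model (make_output_buffers, get_outputs_cudagraph), so a model that also emits routed expert ids can carry them through graph capture. A rough sketch of what such model-side hooks could look like; this is illustrative only, not the actual lmdeploy model code, and the output layout is an assumption:

from typing import Any, Dict

import torch


class MoEModelStub(torch.nn.Module):
    """Illustrative stub, not an lmdeploy model."""

    def make_output_buffers(self, output) -> Dict[str, Any]:
        # Keep every tensor the CUDA graph must reuse between replays.
        if isinstance(output, dict):
            return dict(output)          # e.g. {'logits': ..., 'routed_experts': ...}
        return dict(logits=output)       # logits-only models

    def get_outputs_cudagraph(self, output_buffers: Dict[str, Any], **kwargs):
        # Slice the padded capture buffers back to the real token count.
        num_tokens = kwargs['input_ids'].size(-1)
        outputs = dict(logits=output_buffers['logits'][:, :num_tokens])
        if 'routed_experts' in output_buffers:
            outputs['routed_experts'] = output_buffers['routed_experts'][:num_tokens]
        return outputs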
20 changes: 12 additions & 8 deletions lmdeploy/pytorch/config.py
@@ -347,6 +347,7 @@ class MiscConfig:
disable_vision_encoder: bool = False
logprobs_mode: str = None
dllm_config: DLLMConfig = None
enable_return_routed_experts: bool = False

@classmethod
def from_engine_config(cls, engine_config: PytorchEngineConfig):
@@ -356,12 +357,15 @@ def from_engine_config(cls, engine_config: PytorchEngineConfig):
unmasking_strategy=dllm_unmasking_strategy,
denoising_steps=engine_config.dllm_denoising_steps,
confidence_threshold=engine_config.dllm_confidence_threshold)
misc_config = cls(custom_module_map=engine_config.custom_module_map,
empty_init=engine_config.empty_init,
prefill_interval=engine_config.prefill_interval,
model_format=engine_config.model_format,
hf_overrides=engine_config.hf_overrides,
disable_vision_encoder=engine_config.disable_vision_encoder,
logprobs_mode=engine_config.logprobs_mode,
dllm_config=dllm_config)
misc_config = cls(
custom_module_map=engine_config.custom_module_map,
empty_init=engine_config.empty_init,
prefill_interval=engine_config.prefill_interval,
model_format=engine_config.model_format,
hf_overrides=engine_config.hf_overrides,
disable_vision_encoder=engine_config.disable_vision_encoder,
logprobs_mode=engine_config.logprobs_mode,
dllm_config=dllm_config,
enable_return_routed_experts=engine_config.enable_return_routed_experts,
)
return misc_config
16 changes: 15 additions & 1 deletion lmdeploy/pytorch/engine/engine.py
@@ -57,6 +57,9 @@ class InferOutput:
# for logging
req_metrics: RequestMetrics = None

# expert ids
routed_experts: torch.Tensor = None


def _tensorlize_block_offsets(block_offsets, dtype=torch.int32):
"""Tensorlize block_offsets."""
@@ -876,13 +879,17 @@ def _make_infer_outputs(
cur_logprobs = (logprobs.vals[idx][:num_logprobs + 1], logprobs.indices[idx][:num_logprobs + 1])

req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events)
routed_experts = msg.routed_experts if msg.return_routed_experts and finish else None
if routed_experts is not None:
routed_experts = self.executor.serialize(routed_experts)
out = InferOutput(session_id=session_id,
resp=msg.resp,
finish=finish,
token_ids=token_ids,
cache_block_ids=cache_block_ids,
req_metrics=req_metrics,
logprobs=cur_logprobs)
logprobs=cur_logprobs,
routed_experts=routed_experts)
outputs[session_id] = out

if msg.return_logits:
@@ -896,6 +903,10 @@ def __need_logits(seqs: SeqList):
"""Need logits."""
return any(seq.return_logits for seq in seqs)

def __need_routed_experts(seqs: SeqList):
"""Need routed experts."""
return any(seq.return_routed_experts for seq in seqs)

def __need_schedule_again(prefill: bool, scheduler_output):
"""Need schedule again."""
# only reschedule when prefill
@@ -939,6 +950,7 @@ def __need_schedule_again(prefill: bool, scheduler_output):
inputs = self.create_model_inputs(running, prefill)
sampling_inputs = self.sampling_strategy.make_sampling_inputs(running)
return_logits = __need_logits(running)
return_routed_experts = __need_routed_experts(running)
extra_inputs = self.model_agent_strategy.make_extra_inputs(running)
stopping_criteria = self.model_agent_strategy.make_stopping_criteria(running)

@@ -956,6 +968,7 @@ def __need_schedule_again(prefill: bool, scheduler_output):
is_dummy=False,
sync_long_context=sync_long_context,
extra_inputs=extra_inputs,
return_routed_experts=return_routed_experts,
)

async def _await_forward_event(self, forward_event: asyncio.Event):
@@ -991,6 +1004,7 @@ def __send_resp(out: InferOutput):
logits=out.logits,
cache_block_ids=out.cache_block_ids,
req_metrics=out.req_metrics,
routed_experts=out.routed_experts,
logprobs=logprobs))

def __update_logprobs(step_outputs: List[InferOutput]):
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/engine/engine_instance.py
@@ -152,6 +152,8 @@ async def async_stream_infer(self,
cache_block_ids = resp.data.get('cache_block_ids', None) if resp.data else None
req_metrics = resp.data.get('req_metrics', None) if resp.data else None
logprobs = resp.data.pop('logprobs', None) if resp.data else None
routed_experts = resp.data.get('routed_experts', None) if resp.data else None

if resp.type == ResponseType.SUCCESS:
token_ids = resp.data['token_ids'].tolist()
num_ids = len(token_ids) - output_offset
@@ -160,6 +162,7 @@ async def async_stream_infer(self,
token_ids[output_offset:],
cache_block_ids=cache_block_ids,
req_metrics=req_metrics,
routed_experts=routed_experts,
logprobs=logprobs)
output_offset = len(token_ids)
elif resp.type == ResponseType.FINISH:
@@ -173,6 +176,7 @@ async def async_stream_infer(self,
logits=logits,
cache_block_ids=cache_block_ids,
req_metrics=req_metrics,
routed_experts=routed_experts,
logprobs=logprobs)
break
else:
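
Since the engine attaches routed_experts only when a sequence finishes (see _make_infer_outputs above), a stream consumer can simply keep the value from the final EngineOutput. A minimal sketch, where the generator is whatever async_stream_infer yields:

async def collect_routed_experts(generator):
    routed_experts = None
    async for out in generator:          # out: EngineOutput
        if out.routed_experts is not None:
            routed_experts = out.routed_experts
    return routed_experts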
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/engine/executor/base.py
@@ -101,6 +101,10 @@ def release(self):
"""Release resources."""
raise NotImplementedError('Not Implemented.')

def serialize(self, obj):
"""Serialize obj."""
return obj

async def forward_async(self, inputs):
"""Start forward."""
raise NotImplementedError('Not Implemented')
8 changes: 8 additions & 0 deletions lmdeploy/pytorch/engine/executor/ray_executor.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import base64
import contextlib
import json
import os
@@ -351,6 +352,13 @@ def wakeup(self, tags: Optional[List[str]] = None):
self.update_configs()
self.collective_rpc('wakeup', (tags, ))

def serialize(self, obj) -> str:
"""Serialize obj."""
ref = ray.put(obj)
data = ray.cloudpickle.dumps(ref)
data = base64.b64encode(data).decode('utf-8')
return data

def get_input_processor(self):
"""Build cache engine."""
return ray.get(self.workers[0].get_input_processor.remote())
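
serialize() above parks the object in Ray's object store and returns a base64-encoded pickled ObjectRef, so the routed experts can travel through the response path as a plain string. Reversing it on the consumer side would look roughly like this, assuming the Ray cluster that produced the ref is still alive:

import base64

import ray


def deserialize_routed_experts(data: str):
    ref_bytes = base64.b64decode(data)      # base64 string -> pickled ObjectRef bytes
    ref = ray.cloudpickle.loads(ref_bytes)  # bytes -> ray.ObjectRef
    return ray.get(ref)                     # fetch the tensor from the object store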