
Commit 9ec8c40

add endpoint /abort_request (#4092)
* add endpoint /abort_request
* add finish_reason abort
* enlarge num_instance
* add option --enable-abort-handling
* fix access to None req_metrics when sending abort_request
1 parent a6aa375 commit 9ec8c40

File tree

10 files changed (+79, -24 lines)


lmdeploy/cli/serve.py

Lines changed: 3 additions & 0 deletions
@@ -73,6 +73,7 @@ def add_parser_api_server():
         ArgumentHelper.max_log_len(parser)
         ArgumentHelper.disable_fastapi_docs(parser)
         ArgumentHelper.allow_terminate_by_client(parser)
+        ArgumentHelper.enable_abort_handling(parser)
         # chat template args
         ArgumentHelper.chat_template(parser)

@@ -266,6 +267,7 @@ def api_server(args):
                    allow_methods=args.allow_methods,
                    allow_headers=args.allow_headers,
                    allow_terminate_by_client=args.allow_terminate_by_client,
+                   enable_abort_handling=args.enable_abort_handling,
                    log_level=args.log_level.upper(),
                    api_keys=args.api_keys,
                    ssl=args.ssl,

@@ -293,6 +295,7 @@ def api_server(args):
                    allow_methods=args.allow_methods,
                    allow_headers=args.allow_headers,
                    allow_terminate_by_client=args.allow_terminate_by_client,
+                   enable_abort_handling=args.enable_abort_handling,
                    log_level=args.log_level.upper(),
                    api_keys=args.api_keys,
                    ssl=args.ssl,

lmdeploy/cli/utils.py

Lines changed: 10 additions & 0 deletions
@@ -454,6 +454,16 @@ def allow_terminate_by_client(parser):
                                    default=False,
                                    help='Enable server to be terminated by request from client')

+    @staticmethod
+    def enable_abort_handling(parser):
+        """Add --enable-abort-handling argument to configure server abort
+        request processing."""
+
+        return parser.add_argument('--enable-abort-handling',
+                                   action='store_true',
+                                   default=False,
+                                   help='Enable server to handle client abort requests')
+
     @staticmethod
     def cache_max_entry_count(parser):
         """Add argument cache_max_entry_count to parser."""

lmdeploy/metrics/metrics_processor.py

Lines changed: 3 additions & 1 deletion
@@ -122,7 +122,9 @@ async def _run_metrics_handler(self):
             outputs, req_state, iteration_stats = update_data

             # update request state according the engine events
-            req_state.update_from_events(outputs.req_metrics.engine_events)
+            if outputs and outputs.req_metrics:
+                # when users visit "/abort_request" endpoint, `req_metrics` might be None
+                req_state.update_from_events(outputs.req_metrics.engine_events)

             # update iteration stats based on outputs and request state.
             # some attributes of req_state will also be updated, e.g., lastest_token_time

lmdeploy/metrics/stats.py

Lines changed: 3 additions & 0 deletions
@@ -198,6 +198,9 @@ def update_from_output(self, outputs: EngineOutput, req_state: RequestState):
             outputs (EngineOutput): The output from the engine containing information about the current iteration.
             req_state (RequestState): The state of the request, including timestamps and token counts.
         """
+        if outputs.req_metrics is None:
+            # when users visit "/abort_request" endpoint, `req_metrics` might be None
+            return
         new_generation_tokens = len(outputs.token_ids)
         if new_generation_tokens == 0:
             return

lmdeploy/pytorch/engine/engine.py

Lines changed: 1 addition & 1 deletion
@@ -540,7 +540,7 @@ def _on_stop_session(self, reqs: List[Request], **kwargs):
             for seq in session.sequences.values():
                 _resp: Response = getattr(seq, 'resp', None)
                 if _resp is not None:
-                    _resp.type = ResponseType.FINISH
+                    _resp.type = ResponseType.CANCEL
                     self.req_manager.response(_resp)
             resp_type = ResponseType.SUCCESS
             if resp:

lmdeploy/pytorch/engine/mp_engine/base_worker.py

Lines changed: 2 additions & 1 deletion
@@ -20,7 +20,8 @@ class EngineInstancePool:
     def __init__(self, engine):
         from lmdeploy.pytorch.engine import Engine
         self.engine: Engine = engine
-        self.num_instance = self.engine.engine_config.max_batch_size
+        # enlarge `num_instance`, otherwise an sequence cannot be stopped in time
+        self.num_instance = self.engine.engine_config.max_batch_size * 2
         self.pool = None

     def create_instance_pool(self, num_instance: int):

lmdeploy/serve/async_engine.py

Lines changed: 21 additions & 3 deletions
@@ -444,12 +444,26 @@ async def do_log_stats(self):
         for stat_logger in self.stat_loggers:
             stat_logger.log()

+    async def stop_all_session(self):
+        """Stop all running sessions."""
+        logger.info('stop all sessions')
+        tasks = []
+        session_ids = []
+        for session_id in list(self.id2inst.keys()):
+            generator = self.id2inst.get(session_id)
+            if generator:
+                session_ids.append(session_id)
+                tasks.append(generator.async_cancel(session_id))
+        await asyncio.gather(*tasks)
+        logger.info(f'all {len(session_ids)} sessions stopped')
+
     async def stop_session(self, session_id: int):
         """Stop a session by a session_id."""
         logger.info(f'stop session {session_id}')
         generator = self.id2inst.get(session_id)
         if generator:
             await generator.async_cancel(session_id)
+            logger.info(f'session {session_id} stopped')
         # else it's not running at all

     async def end_session(self, session_id: int):

@@ -855,7 +869,7 @@ def is_error(status):
                     break

                 output_len = len(outputs.token_ids)
-                if hit_stop_token:
+                if hit_stop_token or output_len == 0:
                     continue

                 # This assumes the engine will stop when stop token is hit

@@ -892,7 +906,11 @@ def is_error(status):
                 metrics_processor.increment_finished_requests()

                 if not is_error(outputs.status):
-                    finish_reason = 'stop' if outputs.token_ids[-1] in stop_ids else 'length'
+                    if outputs.status == ResponseType.CANCEL:
+                        finish_reason = 'abort'
+                    else:
+                        finish_reason = 'stop' if outputs.token_ids[-1] in stop_ids else 'length'
+
                     # utf-8 char at the end means it's a potential unfinished byte sequence
                     if not response.endswith('�'):
                         # avoid returning the last response twice

@@ -926,7 +944,7 @@ def is_error(status):
                     output_len = gen_len
                     self.id2step[session_id] += input_len + output_len
                 else:
-                    logger.error(f'session {session_id} finished, '
+                    logger.error(f'session {session_id} finished, {outputs.status}, '
                                  'reason "error"')
                     yield GenOut(response=f'internal error happened, status code {outputs.status}',
                                  history_token_len=self.id2step[session_id],

lmdeploy/serve/openai/api_client.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ def get_model_list(api_url: str, headers: dict = None):
     logger = get_logger('lmdeploy')
     if not response.ok:
         logger.error(f'Failed to get the model list: {api_url}'
-                     'returns {response.status_code}')
+                     f' returns {response.status_code}')
         return None
     elif not hasattr(response, 'text'):
         logger.warning('Failed to get the model list.')

lmdeploy/serve/openai/api_server.py

Lines changed: 21 additions & 13 deletions
@@ -32,7 +32,7 @@
 from lmdeploy.serve.async_engine import AsyncEngine
 from lmdeploy.serve.openai.harmony_utils import GptOssChatParser
 from lmdeploy.serve.openai.protocol import ChatCompletionResponse  # noqa: E501
-from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, ChatCompletionResponseChoice,
+from lmdeploy.serve.openai.protocol import (AbortRequest, ChatCompletionRequest, ChatCompletionResponseChoice,
                                             ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
                                             ChatCompletionTokenLogprob, ChatMessage, ChoiceLogprobs, CompletionRequest,
                                             CompletionResponse, CompletionResponseChoice,

@@ -65,6 +65,7 @@ class VariableInterface:
     # following is for tool parsers
     tool_parser: Optional[ToolParser] = None
     allow_terminate_by_client: bool = False
+    enable_abort_handling: bool = False


 router = APIRouter()

@@ -954,18 +955,8 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
         do_preprocess=False,
     )

-    def create_finish_reason(finish_reason):
-        # TODO: add detail info
-        if not finish_reason:
-            return None
-        if finish_reason == 'length':
-            return dict(type='length')
-        if finish_reason == 'stop':
-            return dict(type='stop')
-        return dict(type='abort')
-
     def create_generate_response_json(res, text, output_ids, logprobs, finish_reason):
-        meta = GenerateReqMetaOutput(finish_reason=create_finish_reason(finish_reason),
+        meta = GenerateReqMetaOutput(finish_reason=dict(type=finish_reason) if finish_reason else None,
                                      output_token_logprobs=logprobs or None,
                                      prompt_tokens=res.input_token_len,
                                      completion_tokens=res.generate_token_len)

@@ -1004,7 +995,7 @@ async def _inner_call():
             for tok, tok_logprobs in zip(res.token_ids, res.logprobs):
                 logprobs.append((tok_logprobs[tok], tok))
             nonlocal response
-            meta = GenerateReqMetaOutput(finish_reason=create_finish_reason(res.finish_reason),
+            meta = GenerateReqMetaOutput(finish_reason=dict(type=res.finish_reason) if res.finish_reason else None,
                                          output_token_logprobs=logprobs or None,
                                          prompt_tokens=res.input_token_len,
                                          completion_tokens=res.generate_token_len)

@@ -1168,6 +1159,21 @@ async def free_cache(cache_free_request: DistServeCacheFreeRequest) -> JSONResponse:
 """ PD Disaggregation API End """


+@router.post('/abort_request')
+async def abort_request(request: AbortRequest, raw_request: Request = None):
+    """Abort an ongoing request."""
+    if not VariableInterface.enable_abort_handling:
+        return Response(
+            status_code=501,
+            content='This server does not support abort requests. Enable with --enable-abort-handling flag.')
+
+    if request.abort_all:
+        await VariableInterface.async_engine.stop_all_session()
+    else:
+        await VariableInterface.async_engine.stop_session(request.session_id)
+    return Response(status_code=200)
+
+
 @router.post('/v1/chat/interactive', dependencies=[Depends(check_api_key)])
 async def chat_interactive_v1(request: GenerateRequest, raw_request: Request = None):
     return create_error_response(

@@ -1332,6 +1338,7 @@ def serve(model_path: str,
           reasoning_parser: Optional[str] = None,
           tool_call_parser: Optional[str] = None,
           allow_terminate_by_client: bool = False,
+          enable_abort_handling: bool = False,
           **kwargs):
     """An example to perform model inference through the command line
     interface.

@@ -1390,6 +1397,7 @@ def serve(model_path: str,
     logger.setLevel(log_level)

     VariableInterface.allow_terminate_by_client = allow_terminate_by_client
+    VariableInterface.enable_abort_handling = enable_abort_handling
     if api_keys is not None:
         if isinstance(api_keys, str):
             api_keys = api_keys.split(',')
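
For reviewers who want to exercise the new route, a hedged client-side sketch follows; the server address is the assumed api_server default and the session id is a placeholder. The endpoint answers 501 unless the server was started with --enable-abort-handling.

# Sketch only: call the new /abort_request endpoint with the AbortRequest
# fields defined in protocol.py below. Address and session id are assumptions.
import requests

base_url = 'http://0.0.0.0:23333'  # assumed default api_server address

# abort one running session (session id is a placeholder)
r = requests.post(f'{base_url}/abort_request', json={'session_id': 42})
print(r.status_code)  # 200 on success, 501 if abort handling is disabled

# abort every running session
r = requests.post(f'{base_url}/abort_request', json={'abort_all': True})
print(r.status_code)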

lmdeploy/serve/openai/protocol.py

Lines changed: 14 additions & 4 deletions
@@ -256,7 +256,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
     logprobs: Optional[ChoiceLogprobs] = None
-    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error']] = None
+    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error', 'abort']] = None


 class ChatCompletionStreamResponse(BaseModel):

@@ -314,7 +314,7 @@ class CompletionResponseChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     gen_tokens: Optional[List[int]] = None
-    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error']] = None
+    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error', 'abort']] = None


 class CompletionResponse(BaseModel):

@@ -333,7 +333,7 @@ class CompletionResponseStreamChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     gen_tokens: Optional[List[int]] = None
-    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error']] = None
+    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error', 'abort']] = None


 class CompletionStreamResponse(BaseModel):

@@ -430,7 +430,7 @@ class GenerateResponse(BaseModel):
     tokens: int
     input_tokens: int
     history_tokens: int
-    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error']] = None
+    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error', 'abort']] = None


 class UpdateParamsRequest(BaseModel):

@@ -478,3 +478,13 @@ class GenerateReqOutput(BaseModel):
     text: str
     output_ids: List[int]
     meta_info: GenerateReqMetaOutput
+
+
+class AbortRequest(BaseModel):
+    # Whether to abort all requests
+    abort_all: bool = False
+    # The finished reason data
+    finished_reason: Optional[Dict[str, Any]] = None
+    abort_message: Optional[str] = None
+    # The session ID to abort. If `abort_all` is True, this field is ignored.
+    session_id: Optional[int] = -1
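
Illustrative only: the request body shapes the new AbortRequest model accepts; the field values below are made up.

from lmdeploy.serve.openai.protocol import AbortRequest

one = AbortRequest(session_id=7)  # abort a single session
everything = AbortRequest(abort_all=True, abort_message='client shutdown')  # session_id is ignored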
