32 | 32 | from lmdeploy.serve.async_engine import AsyncEngine |
33 | 33 | from lmdeploy.serve.openai.harmony_utils import GptOssChatParser |
34 | 34 | from lmdeploy.serve.openai.protocol import ChatCompletionResponse # noqa: E501 |
35 | | -from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, ChatCompletionResponseChoice, |
| 35 | +from lmdeploy.serve.openai.protocol import (AbortRequest, ChatCompletionRequest, ChatCompletionResponseChoice, |
36 | 36 | ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, |
37 | 37 | ChatCompletionTokenLogprob, ChatMessage, ChoiceLogprobs, CompletionRequest, |
38 | 38 | CompletionResponse, CompletionResponseChoice, |
@@ -65,6 +65,7 @@ class VariableInterface: |
65 | 65 | # following is for tool parsers |
66 | 66 | tool_parser: Optional[ToolParser] = None |
67 | 67 | allow_terminate_by_client: bool = False |
| 68 | + enable_abort_handling: bool = False |
68 | 69 |
69 | 70 |
70 | 71 | router = APIRouter() |
@@ -954,18 +955,8 @@ async def generate(request: GenerateReqInput, raw_request: Request = None): |
954 | 955 | do_preprocess=False, |
955 | 956 | ) |
956 | 957 |
957 | | - def create_finish_reason(finish_reason): |
958 | | - # TODO: add detail info |
959 | | - if not finish_reason: |
960 | | - return None |
961 | | - if finish_reason == 'length': |
962 | | - return dict(type='length') |
963 | | - if finish_reason == 'stop': |
964 | | - return dict(type='stop') |
965 | | - return dict(type='abort') |
966 | | - |
967 | 958 | def create_generate_response_json(res, text, output_ids, logprobs, finish_reason): |
968 | | - meta = GenerateReqMetaOutput(finish_reason=create_finish_reason(finish_reason), |
| 959 | + meta = GenerateReqMetaOutput(finish_reason=dict(type=finish_reason) if finish_reason else None, |
969 | 960 | output_token_logprobs=logprobs or None, |
970 | 961 | prompt_tokens=res.input_token_len, |
971 | 962 | completion_tokens=res.generate_token_len) |
@@ -1004,7 +995,7 @@ async def _inner_call(): |
1004 | 995 | for tok, tok_logprobs in zip(res.token_ids, res.logprobs): |
1005 | 996 | logprobs.append((tok_logprobs[tok], tok)) |
1006 | 997 | nonlocal response |
1007 | | - meta = GenerateReqMetaOutput(finish_reason=create_finish_reason(res.finish_reason), |
| 998 | + meta = GenerateReqMetaOutput(finish_reason=dict(type=res.finish_reason) if res.finish_reason else None, |
1008 | 999 | output_token_logprobs=logprobs or None, |
1009 | 1000 | prompt_tokens=res.input_token_len, |
1010 | 1001 | completion_tokens=res.generate_token_len) |
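
A note on the refactor above: the deleted `create_finish_reason` helper mapped any reason other than `'length'` or `'stop'` to `{'type': 'abort'}`, whereas the inlined `dict(type=finish_reason)` forwards the engine's reason string verbatim. A minimal sketch of the behavioral difference (`'error'` is a stand-in for an arbitrary other reason string, not a value confirmed by this diff):

```python
def create_finish_reason(finish_reason):  # the helper removed above
    if not finish_reason:
        return None
    if finish_reason == 'length':
        return dict(type='length')
    if finish_reason == 'stop':
        return dict(type='stop')
    return dict(type='abort')  # any other reason collapsed to 'abort'

finish_reason = 'error'  # hypothetical reason string
print(create_finish_reason(finish_reason))                  # {'type': 'abort'}
print(dict(type=finish_reason) if finish_reason else None)  # {'type': 'error'}
```
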
@@ -1168,6 +1159,21 @@ async def free_cache(cache_free_request: DistServeCacheFreeRequest) -> JSONRespo |
1168 | 1159 | """ PD Disaggregation API End """ |
1169 | 1160 |
1170 | 1161 |
| 1162 | +@router.post('/abort_request') |
| 1163 | +async def abort_request(request: AbortRequest, raw_request: Request = None): |
| 1164 | + """Abort an ongoing request.""" |
| 1165 | + if not VariableInterface.enable_abort_handling: |
| 1166 | + return Response( |
| 1167 | + status_code=501, |
| 1168 | +             content='This server does not support abort requests. Enable it with the --enable-abort-handling flag.') |
| 1169 | + |
| 1170 | + if request.abort_all: |
| 1171 | + await VariableInterface.async_engine.stop_all_session() |
| 1172 | + else: |
| 1173 | + await VariableInterface.async_engine.stop_session(request.session_id) |
| 1174 | + return Response(status_code=200) |
| 1175 | + |
| 1176 | + |
1171 | 1177 | @router.post('/v1/chat/interactive', dependencies=[Depends(check_api_key)]) |
1172 | 1178 | async def chat_interactive_v1(request: GenerateRequest, raw_request: Request = None): |
1173 | 1179 | return create_error_response( |
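
With the endpoint above, a client can cancel a single session by id or abort everything in flight. A minimal usage sketch, assuming the server listens on the default `http://localhost:23333` and that `AbortRequest` exposes exactly the `session_id` and `abort_all` fields the handler reads (the schema itself is not shown in this diff):

```python
import requests

BASE_URL = 'http://localhost:23333'  # assumed server address

# Abort one session; the id must match the session being generated.
requests.post(f'{BASE_URL}/abort_request', json={'session_id': 1})

# Abort every in-flight request.
requests.post(f'{BASE_URL}/abort_request', json={'abort_all': True})
```

Both calls return 200 on success, or 501 if the server was started without `--enable-abort-handling`.
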
@@ -1332,6 +1338,7 @@ def serve(model_path: str, |
1332 | 1338 | reasoning_parser: Optional[str] = None, |
1333 | 1339 | tool_call_parser: Optional[str] = None, |
1334 | 1340 | allow_terminate_by_client: bool = False, |
| 1341 | + enable_abort_handling: bool = False, |
1335 | 1342 | **kwargs): |
1336 | 1343 | """An example to perform model inference through the command line |
1337 | 1344 | interface. |
@@ -1390,6 +1397,7 @@ def serve(model_path: str, |
1390 | 1397 | logger.setLevel(log_level) |
1391 | 1398 |
1392 | 1399 | VariableInterface.allow_terminate_by_client = allow_terminate_by_client |
| 1400 | + VariableInterface.enable_abort_handling = enable_abort_handling |
1393 | 1401 | if api_keys is not None: |
1394 | 1402 | if isinstance(api_keys, str): |
1395 | 1403 | api_keys = api_keys.split(',') |
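
The new flag is threaded from `serve()` into `VariableInterface.enable_abort_handling`, so enabling it is a one-argument change when launching programmatically. A sketch, assuming this module is `lmdeploy.serve.openai.api_server` and using a placeholder model path:

```python
from lmdeploy.serve.openai.api_server import serve

# Placeholder model path; all other serve() arguments keep their defaults.
serve('internlm/internlm2_5-7b-chat', enable_abort_handling=True)
```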