Commit 6971c15: fix deploy
Jintao-Huang committed Dec 30, 2024
1 parent 9b078a1

Showing 7 changed files with 21 additions and 14 deletions.
1 change: 0 additions & 1 deletion swift/llm/argument/infer_args.py
@@ -112,7 +112,6 @@ class InferArguments(MergeArguments, VllmArguments, LmdeployArguments, BaseArgum
     infer_backend: Literal['vllm', 'pt', 'lmdeploy'] = 'pt'

     result_path: Optional[str] = None
-    writer_buffer_size: int = 65536
     # for pt engine
     max_batch_size: int = 1
     ddp_backend: Optional[str] = None
7 changes: 5 additions & 2 deletions swift/llm/infer/deploy.py
@@ -5,6 +5,7 @@
 import time
 from contextlib import contextmanager
 from dataclasses import asdict
+from functools import partial
 from http import HTTPStatus
 from threading import Thread
 from typing import List, Optional, Union

@@ -153,10 +154,12 @@ def pre_infer_hook(kwargs):
             logger.info(request_info)
             return kwargs

-        self.infer_engine.pre_infer_hooks = [pre_infer_hook]
+        infer_kwargs['pre_infer_hook'] = pre_infer_hook
         try:
             res_or_gen = await self.infer_async(infer_request, request_config, template=self.template, **infer_kwargs)
-        except ValueError as e:
+        except Exception as e:
+            import traceback
+            print(traceback.format_exc())
             return self.create_error_response(HTTPStatus.BAD_REQUEST, str(e))
         if request_config.stream:
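Note: the hook is no longer registered on the engine as self.infer_engine.pre_infer_hooks; it is passed per request through infer_kwargs and forwarded to infer_async. As the engine-side code below shows, the hook receives the keyword-argument dict the engine has assembled and must return it. A minimal sketch of a conforming hook, with a hypothetical name (log_inputs_hook), for illustration only:

# Minimal sketch of the pre_infer_hook contract implied by this commit (hypothetical
# hook name; the 'inputs' key is one of those the engines place in kwargs below).
def log_inputs_hook(kwargs):
    inputs = kwargs.get('inputs')
    print(f'pre-infer, input keys: {sorted(inputs) if inputs else None}')
    return kwargs  # the hook must return the (possibly modified) kwargs dict

# Per-request wiring, mirroring deploy.py above:
#     infer_kwargs['pre_infer_hook'] = log_inputs_hook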
3 changes: 1 addition & 2 deletions swift/llm/infer/infer.py
@@ -75,8 +75,7 @@ def get_infer_engine(args: InferArguments, **kwargs):

     def main(self):
         args = self.args
-        context = open_jsonl_writer(
-            args.result_path, buffer_size=args.writer_buffer_size) if args.result_path else nullcontext()
+        context = open_jsonl_writer(args.result_path) if args.result_path else nullcontext()
         with context as json_writer:
             self.jsonl_writer = json_writer
             return super().main()
9 changes: 6 additions & 3 deletions swift/llm/infer/infer_engine/infer_engine.py
@@ -30,7 +30,6 @@ def _post_init(self):
         self.model_name = self.model_info.model_name
         self.max_model_len = self.model_info.max_model_len
         self.config = self.model_info.config
-        self.pre_infer_hooks = []
         if getattr(self, 'default_template', None) is None:
             self.default_template = get_template(self.model_meta.template, self.processor)
         self._adapters_pool = {}

@@ -60,7 +59,9 @@ async def _run_infer(i, task, queue, stream: bool = False):
                         queue.put((i, stream_response))
                 else:
                     queue.put((i, await task))
-            finally:
+            except Exception as e:
+                queue.put((i, e))
+            else:
                 queue.put((i, None))

         async def _batch_run(tasks):

@@ -78,7 +79,9 @@ async def _batch_run(tasks):

             while n_finished < len(new_tasks):
                 i, output = queue.get()
-                if output is None:  # is_finished
+                if isinstance(output, Exception):
+                    raise output
+                elif output is None:  # is_finished
                     n_finished += 1
                     prog_bar.update()
                 else:
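Note: previously _run_infer put the finish sentinel in a finally: block, so a worker that raised still looked finished and its exception was lost; _batch_run would misreport the result. With this change the worker puts the exception itself on the queue and the consumer re-raises it. A standalone sketch of this producer/consumer error-propagation pattern, using hypothetical names (worker, fake_task, consume), not swift's own API:

# Each worker puts its result, then either the exception it hit or a None sentinel;
# the consumer re-raises any exception instead of waiting forever.
import asyncio
from queue import Queue


async def worker(i, coro, queue: Queue):
    try:
        queue.put((i, await coro))
    except Exception as e:
        queue.put((i, e))      # ship the failure to the consumer
    else:
        queue.put((i, None))   # sentinel: task i finished cleanly


async def fake_task(i):
    if i == 2:
        raise ValueError('boom')
    return i * i


def consume(queue: Queue, n_tasks: int):
    n_finished, results = 0, {}
    while n_finished < n_tasks:
        i, output = queue.get()
        if isinstance(output, Exception):
            raise output       # surface the worker's failure here
        elif output is None:   # is_finished
            n_finished += 1
        else:
            results[i] = output
    return results


async def run_all(queue: Queue, n: int):
    await asyncio.gather(*(worker(i, fake_task(i), queue) for i in range(n)))


q = Queue()
asyncio.run(run_all(q, 3))
try:
    print(consume(q, 3))       # never prints: worker 2 failed
except ValueError as e:
    print('propagated:', e)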
3 changes: 2 additions & 1 deletion swift/llm/infer/infer_engine/lmdeploy_engine.py
@@ -255,6 +255,7 @@ async def infer_async(self,
                           request_config: Optional[RequestConfig] = None,
                           *,
                           template: Optional[Template] = None,
+                          pre_infer_hook=None,
                           **kwargs) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
         request_config = deepcopy(request_config or RequestConfig())
         if template is None:

@@ -275,7 +276,7 @@
         generation_config = self._prepare_generation_config(request_config)
         self._add_stop_words(generation_config, request_config, template.template_meta)
         kwargs.update({'template': template, 'inputs': inputs, 'generation_config': generation_config})
-        for pre_infer_hook in self.pre_infer_hooks:
+        if pre_infer_hook:
             kwargs = pre_infer_hook(kwargs)
         if request_config.stream:
             return self._infer_stream_async(**kwargs)
9 changes: 5 additions & 4 deletions swift/llm/infer/infer_engine/pt_engine.py
@@ -352,11 +352,11 @@ async def infer_async(
         # TODO:auto batch
         if request_config is None:
             request_config = RequestConfig()
-        res_or_gen = self.infer([infer_request],
+        res_or_gen = self._infer([infer_request],
                                  request_config,
                                  template=template,
                                  use_tqdm=False,
-                                 adapter_request=adapter_request)
+                                 adapter_request=adapter_request,
+                                 pre_infer_hook=pre_infer_hook)
         if request_config.stream:

             async def _gen_wrapper():

@@ -376,6 +376,7 @@ def _infer(
         *,
         template: Optional[Template] = None,
         adapter_request: Optional[AdapterRequest] = None,
+        pre_infer_hook=None,
     ) -> Union[List[ChatCompletionResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
         self.model.eval()
         request_config = deepcopy(request_config)

@@ -414,7 +415,7 @@ def _infer(
             'adapter_request': adapter_request,
             'template_inputs': template_inputs
         }
-        for pre_infer_hook in self.pre_infer_hooks:
+        if pre_infer_hook:
             kwargs = pre_infer_hook(kwargs)
         if request_config.stream:

3 changes: 2 additions & 1 deletion swift/llm/infer/infer_engine/vllm_engine.py
@@ -364,6 +364,7 @@ async def infer_async(
         *,
         template: Optional[Template] = None,
         adapter_request: Optional[AdapterRequest] = None,
+        pre_infer_hook=None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
         request_config = deepcopy(request_config or RequestConfig())
         if template is None:

@@ -381,7 +382,7 @@
             'generation_config': generation_config,
             'adapter_request': adapter_request
         }
-        for pre_infer_hook in self.pre_infer_hooks:
+        if pre_infer_hook:
             kwargs = pre_infer_hook(kwargs)
         if request_config.stream:
             return self._infer_stream_async(**kwargs)
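Note: across all three backends the old self.pre_infer_hooks list is replaced by a single per-call pre_infer_hook argument to infer_async. A caller that still wants several hooks can compose them before passing the result in; a hypothetical helper (chain_hooks is not part of swift):

# Hypothetical helper: compose several hooks into the single pre_infer_hook
# that infer_async now accepts.
def chain_hooks(*hooks):
    def composed(kwargs):
        for hook in hooks:
            kwargs = hook(kwargs)  # each hook returns the dict for the next one
        return kwargs
    return composed

# e.g. infer_kwargs['pre_infer_hook'] = chain_hooks(log_inputs_hook, another_hook)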
