diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index 13a5f496a1..4aba9dce7f 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -4,7 +4,7 @@
 import os.path as osp
 import random
 from contextlib import contextmanager
-from typing import Literal, Optional
+from typing import List, Literal, Optional
 
 from lmdeploy.model import MODELS, BaseModel
 
@@ -46,6 +46,7 @@ def __init__(self, model_path, instance_num=32, tp=1) -> None:
         self.available = [True] * instance_num
         self.starts = [None] * instance_num
         self.steps = {}
+        self.loop = asyncio.get_event_loop()
 
     def stop_session(self, session_id: int):
         instance_id = session_id % self.instance_num
@@ -82,6 +83,59 @@ async def get_generator(self, instance_id: int, stop: bool = False):
             await asyncio.sleep(0.1)
         return self.generators[instance_id]
 
+    def batch_infer(self,
+                    prompts: List[str],
+                    request_output_len=512,
+                    top_k=40,
+                    top_p=0.8,
+                    temperature=0.8,
+                    repetition_penalty=1.0,
+                    ignore_eos=False,
+                    **kwargs):
+        """Inference a batch of prompts.
+
+        Args:
+            prompts (List[str]): a batch of prompts
+            request_output_len (int): output token nums
+            top_k (int): The number of the highest probability vocabulary
+                tokens to keep for top-k-filtering
+            top_p (float): If set to float < 1, only the smallest set of most
+                probable tokens with probabilities that add up to top_p or
+                higher are kept for generation.
+            temperature (float): to modulate the next token probability
+            repetition_penalty (float): The parameter for repetition penalty.
+                1.0 means no penalty
+            ignore_eos (bool): indicator for ignoring eos
+        """
+        assert isinstance(prompts, List), 'prompts should be a list'
+        batch_size = len(prompts)
+        outputs = [''] * batch_size
+        generators = []
+        for i, prompt in enumerate(prompts):
+            generators.append(
+                self.generate(prompt,
+                              i,
+                              stream_response=True,
+                              sequence_start=True,
+                              sequence_end=True,
+                              request_output_len=request_output_len,
+                              top_k=top_k,
+                              top_p=top_p,
+                              temperature=temperature,
+                              ignore_eos=ignore_eos,
+                              repetition_penalty=repetition_penalty))
+
+        async def _inner_call(i, generator):
+            async for out in generator:
+                outputs[i] += out.response
+
+        async def gather():
+            await asyncio.gather(
+                *[_inner_call(i, generators[i]) for i in range(batch_size)])
+
+        self.loop.run_until_complete(gather())
+        return outputs
+
     async def generate(
         self,
         messages,
@@ -109,11 +163,11 @@ async def generate(
             sequence_end (bool): indicator for ending a sequence
             step (int): the offset of the k/v cache
             stop (bool): whether stop inference
-            top_p (float): If set to float < 1, only the smallest set of most
-                probable tokens with probabilities that add up to top_p or
-                higher are kept for generation.
             top_k (int): The number of the highest probability vocabulary
                 tokens to keep for top-k-filtering
+            top_p (float): If set to float < 1, only the smallest set of most
+                probable tokens with probabilities that add up to top_p or
+                higher are kept for generation.
             temperature (float): to modulate the next token probability
             repetition_penalty (float): The parameter for repetition penalty.
                 1.0 means no penalty
@@ -195,11 +249,11 @@ async def generate_openai(
             renew_session (bool): renew the session
             request_output_len (int): output token nums
             stop (bool): whether stop inference
-            top_p (float): If set to float < 1, only the smallest set of most
-                probable tokens with probabilities that add up to top_p or
-                higher are kept for generation.
             top_k (int): The number of the highest probability vocabulary
                 tokens to keep for top-k-filtering
+            top_p (float): If set to float < 1, only the smallest set of most
+                probable tokens with probabilities that add up to top_p or
+                higher are kept for generation.
             temperature (float): to modulate the next token probability
             repetition_penalty (float): The parameter for repetition penalty.
                 1.0 means no penalty
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index 8324e3497f..3ba8b80b4b 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -229,11 +229,19 @@ async def create_embeddings(request: EmbeddingsRequest,
     error_check_ret = await check_request(request)
    if error_check_ret is not None:
         return error_check_ret
-
-    embedding = await VariableInterface.async_engine.get_embeddings(
-        request.input)
-    data = [{'object': 'embedding', 'embedding': embedding, 'index': 0}]
-    token_num = len(embedding)
+    if isinstance(request.input, str):
+        request.input = [request.input]
+
+    data = []
+    token_num = 0
+    for i, prompt in enumerate(request.input):
+        embedding = await VariableInterface.async_engine.get_embeddings(prompt)
+        data.append({
+            'object': 'embedding',
+            'embedding': embedding,
+            'index': i
+        })
+        token_num += len(embedding)
     return EmbeddingsResponse(
         data=data,
         model=request.model,
diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py
index 756af1a4ca..78bf56531b 100644
--- a/lmdeploy/serve/openai/protocol.py
+++ b/lmdeploy/serve/openai/protocol.py
@@ -175,7 +175,7 @@ class CompletionStreamResponse(BaseModel):
 
 
 class EmbeddingsRequest(BaseModel):
     """Embedding request.""" 
     model: str = None
-    input: Union[str, List[Any]]
+    input: Union[str, List[str]]
     user: Optional[str] = None
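
Usage sketch for the new batch_infer API (not part of the patch). It assumes the engine class defined in async_engine.py is named AsyncEngine and uses a placeholder model path; only the method signature comes from the diff above.

from lmdeploy.serve.async_engine import AsyncEngine

# Placeholder model path and instance count; adjust to your deployment.
engine = AsyncEngine(model_path='/path/to/a/converted/model', instance_num=4, tp=1)

prompts = ['Hi, please introduce yourself', 'Shanghai is']
# batch_infer builds one generator per prompt, runs them concurrently via
# asyncio.gather on the engine's event loop, and returns the accumulated
# responses in prompt order. It blocks until every prompt has finished.
outputs = engine.batch_infer(prompts,
                             request_output_len=128,
                             top_k=40,
                             top_p=0.8,
                             temperature=0.8)
for prompt, output in zip(prompts, outputs):
    print(f'{prompt!r} -> {output!r}')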
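
The api_server.py hunk lets the embeddings endpoint accept either a single string or a list of strings as `input`, with one indexed embedding returned per item. A rough client-side sketch follows; the route '/v1/embeddings', host, port, and model name are assumptions, not taken from the diff.

import requests

payload = {
    'model': 'internlm-chat-7b',                 # placeholder model name
    'input': ['Hello world', 'How are you today?'],  # a list is now accepted
}
resp = requests.post('http://0.0.0.0:23333/v1/embeddings', json=payload)
# Each element of `data` carries its own `index`, so results can be matched
# back to the corresponding input string.
for item in resp.json()['data']:
    print(item['index'], len(item['embedding']))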