support inference a batch of prompts (#467)
* support inference a batch of prompts

* docstring and assert
AllentDan authored Oct 25, 2023
1 parent 169d516 commit ac3500b
Showing 3 changed files with 75 additions and 13 deletions.
68 changes: 61 additions & 7 deletions lmdeploy/serve/async_engine.py
@@ -4,7 +4,7 @@
 import os.path as osp
 import random
 from contextlib import contextmanager
-from typing import Literal, Optional
+from typing import List, Literal, Optional

 from lmdeploy.model import MODELS, BaseModel

@@ -46,6 +46,7 @@ def __init__(self, model_path, instance_num=32, tp=1) -> None:
         self.available = [True] * instance_num
         self.starts = [None] * instance_num
         self.steps = {}
+        self.loop = asyncio.get_event_loop()

     def stop_session(self, session_id: int):
         instance_id = session_id % self.instance_num
@@ -82,6 +83,59 @@ async def get_generator(self, instance_id: int, stop: bool = False):
             await asyncio.sleep(0.1)
         return self.generators[instance_id]

+    def batch_infer(self,
+                    prompts: List[str],
+                    request_output_len=512,
+                    top_k=40,
+                    top_p=0.8,
+                    temperature=0.8,
+                    repetition_penalty=1.0,
+                    ignore_eos=False,
+                    **kwargs):
+        """Inference a batch of prompts.
+
+        Args:
+            prompts (List[str]): a batch of prompts
+            request_output_len (int): output token nums
+            top_k (int): The number of the highest probability vocabulary
+            tokens to keep for top-k-filtering
+            top_p (float): If set to float < 1, only the smallest set of most
+            probable tokens with probabilities that add up to top_p or higher
+            are kept for generation.
+            temperature (float): to modulate the next token probability
+            repetition_penalty (float): The parameter for repetition penalty.
+            1.0 means no penalty
+            ignore_eos (bool): indicator for ignoring eos
+        """
+        assert isinstance(prompts, List), 'prompts should be a list'
+        batch_size = len(prompts)
+        outputs = [''] * batch_size
+        generators = []
+        for i, prompt in enumerate(prompts):
+            generators.append(
+                self.generate(prompt,
+                              i,
+                              stream_response=True,
+                              sequence_start=True,
+                              sequence_end=True,
+                              request_output_len=request_output_len,
+                              top_k=top_k,
+                              top_p=top_p,
+                              temperature=temperature,
+                              ignore_eos=ignore_eos,
+                              repetition_penalty=repetition_penalty))
+
+        async def _inner_call(i, generator):
+            async for out in generator:
+                outputs[i] += out.response
+
+        async def gather():
+            await asyncio.gather(
+                *[_inner_call(i, generators[i]) for i in range(batch_size)])
+
+        self.loop.run_until_complete(gather())
+        return outputs
+
     async def generate(
         self,
         messages,
@@ -109,11 +163,11 @@ async def generate(
             sequence_end (bool): indicator for ending a sequence
             step (int): the offset of the k/v cache
             stop (bool): whether stop inference
-            top_p (float): If set to float < 1, only the smallest set of most
-            probable tokens with probabilities that add up to top_p or higher
-            are kept for generation.
             top_k (int): The number of the highest probability vocabulary
             tokens to keep for top-k-filtering
+            top_p (float): If set to float < 1, only the smallest set of most
+            probable tokens with probabilities that add up to top_p or higher
+            are kept for generation.
             temperature (float): to modulate the next token probability
             repetition_penalty (float): The parameter for repetition penalty.
             1.0 means no penalty
@@ -195,11 +249,11 @@ async def generate_openai(
             renew_session (bool): renew the session
             request_output_len (int): output token nums
             stop (bool): whether stop inference
-            top_p (float): If set to float < 1, only the smallest set of most
-            probable tokens with probabilities that add up to top_p or higher
-            are kept for generation.
             top_k (int): The number of the highest probability vocabulary
             tokens to keep for top-k-filtering
+            top_p (float): If set to float < 1, only the smallest set of most
+            probable tokens with probabilities that add up to top_p or higher
+            are kept for generation.
             temperature (float): to modulate the next token probability
             repetition_penalty (float): The parameter for repetition penalty.
             1.0 means no penalty
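
For reference, a minimal sketch of how the new batch_infer entry point could be called once this change lands; the class name AsyncEngine comes from this file, while the model path and prompt strings below are placeholders rather than part of the commit:

# Hypothetical usage sketch of batch_infer; the model path and prompts are
# placeholders, not taken from this commit.
from lmdeploy.serve.async_engine import AsyncEngine

engine = AsyncEngine(model_path='./workspace', instance_num=32, tp=1)
prompts = ['Hi, please introduce yourself', 'Shanghai is']
# One generator per prompt is created internally; the streamed chunks are
# gathered on the engine's event loop and concatenated per prompt.
outputs = engine.batch_infer(prompts, request_output_len=512, top_k=40, top_p=0.8)
for prompt, output in zip(prompts, outputs):
    print(prompt, '->', output)

Because batch_infer drives the per-prompt coroutines with run_until_complete on the loop stored in __init__, it gives callers a blocking, synchronous interface on top of the async generate API.
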
18 changes: 13 additions & 5 deletions lmdeploy/serve/openai/api_server.py
@@ -229,11 +229,19 @@ async def create_embeddings(request: EmbeddingsRequest,
     error_check_ret = await check_request(request)
     if error_check_ret is not None:
         return error_check_ret
-
-    embedding = await VariableInterface.async_engine.get_embeddings(
-        request.input)
-    data = [{'object': 'embedding', 'embedding': embedding, 'index': 0}]
-    token_num = len(embedding)
+    if isinstance(request.input, str):
+        request.input = [request.input]
+
+    data = []
+    token_num = 0
+    for i, prompt in enumerate(request.input):
+        embedding = await VariableInterface.async_engine.get_embeddings(prompt)
+        data.append({
+            'object': 'embedding',
+            'embedding': embedding,
+            'index': i
+        })
+        token_num += len(embedding)
     return EmbeddingsResponse(
         data=data,
         model=request.model,
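
To show what the per-item loop above enables, a hedged client-side sketch; the /v1/embeddings route, host, port, and model name are assumptions based on the OpenAI-style server, not details shown in this diff:

# Hypothetical request against the embeddings endpoint after this change;
# the route, port, and model name are assumptions, not part of the diff.
import requests

resp = requests.post(
    'http://localhost:23333/v1/embeddings',
    json={
        'model': 'internlm-chat-7b',
        'input': ['Hello world', 'batched embeddings in one request'],
    })
# Each input string now gets its own entry, with `index` matching its position.
for item in resp.json()['data']:
    print(item['index'], len(item['embedding']))
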
2 changes: 1 addition & 1 deletion lmdeploy/serve/openai/protocol.py
@@ -175,7 +175,7 @@ class CompletionStreamResponse(BaseModel):
 class EmbeddingsRequest(BaseModel):
     """Embedding request."""
     model: str = None
-    input: Union[str, List[Any]]
+    input: Union[str, List[str]]
     user: Optional[str] = None


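
A small sketch of what the narrowed annotation means for validation; pydantic already backs these request models, and the sample inputs below are illustrative only:

# Hypothetical check of the tightened EmbeddingsRequest schema: `input` may be
# a single string or a list of strings, but no longer arbitrary list items.
from typing import List, Optional, Union

from pydantic import BaseModel, ValidationError


class EmbeddingsRequest(BaseModel):
    """Embedding request."""
    model: str = None
    input: Union[str, List[str]]
    user: Optional[str] = None


EmbeddingsRequest(input='a single prompt')             # accepted
EmbeddingsRequest(input=['prompt one', 'prompt two'])  # accepted
try:
    EmbeddingsRequest(input=[{'not': 'a string'}])
except ValidationError as err:
    print(err)  # list items must now be strings
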
