Revert "add best_of and use_beam_search for completions interface" (#…
merrymercy authored Sep 11, 2023
1 parent 13f40b3 commit 77aa4df
Showing 4 changed files with 26 additions and 79 deletions.
2 changes: 1 addition & 1 deletion fastchat/protocol/api_protocol.py
@@ -150,7 +150,7 @@ class CompletionResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseChoice]
-    usage: Union[UsageInfo, List[UsageInfo]]
+    usage: UsageInfo


 class CompletionResponseStreamChoice(BaseModel):
4 changes: 1 addition & 3 deletions fastchat/protocol/openai_api_protocol.py
@@ -151,13 +151,11 @@ class CompletionRequest(BaseModel):
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     user: Optional[str] = None
-    use_beam_search: Optional[bool] = False
-    best_of: Optional[int] = None


 class CompletionResponseChoice(BaseModel):
     index: int
-    text: Union[str, List[str]]
+    text: str
     logprobs: Optional[int] = None
     finish_reason: Optional[Literal["stop", "length"]] = None

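With best_of and use_beam_search removed from CompletionRequest and choice text back to a plain string, the completions schema matches the stock OpenAI shape again. The snippet below is a minimal client-side sketch of a request and response against a locally running fastchat.serve.openai_api_server; the model name, prompt, port, and token counts are illustrative assumptions, not values taken from this diff.

import requests

payload = {
    "model": "vicuna-7b-v1.5",    # illustrative model name
    "prompt": "Once upon a time",
    "n": 2,                       # still supported: n independent completions
    "max_tokens": 32,
    # "best_of" and "use_beam_search" are no longer part of CompletionRequest
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload).json()

for choice in resp["choices"]:
    assert isinstance(choice["text"], str)  # each choice carries a plain string again
print(resp["usage"])  # a single usage object, e.g. {"prompt_tokens": 4, ...}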
29 changes: 3 additions & 26 deletions fastchat/serve/openai_api_server.py
@@ -241,9 +241,6 @@ async def get_gen_params(
     max_tokens: Optional[int],
     echo: Optional[bool],
     stop: Optional[Union[str, List[str]]],
-    best_of: Optional[int] = None,
-    n: Optional[int] = 1,
-    use_beam_search: Optional[bool] = None,
 ) -> Dict[str, Any]:
     conv = await get_conv(model_name, worker_addr)
     conv = Conversation(
@@ -290,11 +287,6 @@ async def get_gen_params(
         "stop_token_ids": conv.stop_token_ids,
     }

-    if best_of is not None:
-        gen_params.update({"n": n, "best_of": best_of})
-    if use_beam_search is not None:
-        gen_params.update({"use_beam_search": use_beam_search})
-
     new_stop = set()
     _add_to_set(stop, new_stop)
     _add_to_set(conv.stop_str, new_stop)
@@ -502,18 +494,12 @@ async def create_completion(request: CompletionRequest):
                 max_tokens=request.max_tokens,
                 echo=request.echo,
                 stop=request.stop,
-                best_of=request.best_of,
-                n=request.n,
-                use_beam_search=request.use_beam_search,
             )
             for i in range(request.n):
                 content = asyncio.create_task(
                     generate_completion(gen_params, worker_addr)
                 )
                 text_completions.append(content)
-                # when use with best_of, only need send one request
-                if request.best_of:
-                    break

         try:
             all_tasks = await asyncio.gather(*text_completions)
@@ -533,18 +519,9 @@ async def create_completion(request: CompletionRequest):
                     finish_reason=content.get("finish_reason", "stop"),
                 )
             )
-            idx = 0
-            while True:
-                info = content["usage"]
-                if isinstance(info, list):
-                    info = info[idx]
-
-                task_usage = UsageInfo.parse_obj(info)
-
-                for usage_key, usage_value in task_usage.dict().items():
-                    setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
-                idx += 1
-                break
+            task_usage = UsageInfo.parse_obj(content["usage"])
+            for usage_key, usage_value in task_usage.dict().items():
+                setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

         return CompletionResponse(
             model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage)
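The restored create_completion path always schedules request.n generation tasks and folds each task's usage dict into one running UsageInfo total. Below is a self-contained illustration of that aggregation step; the local UsageInfo stand-in and the sample token counts are assumptions made for the example, while parse_obj/dict follow the pydantic v1 API used in the diff above.

from pydantic import BaseModel


class UsageInfo(BaseModel):
    # Local stand-in with the same fields summed in create_completion.
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


# e.g. one usage dict per generate_completion task (request.n == 2)
task_usages = [
    {"prompt_tokens": 4, "completion_tokens": 16, "total_tokens": 20},
    {"prompt_tokens": 4, "completion_tokens": 12, "total_tokens": 16},
]

usage = UsageInfo()
for info in task_usages:
    task_usage = UsageInfo.parse_obj(info)
    for usage_key, usage_value in task_usage.dict().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

print(usage)  # prompt_tokens=8 completion_tokens=28 total_tokens=36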
70 changes: 21 additions & 49 deletions fastchat/serve/vllm_worker.py
@@ -18,7 +18,6 @@
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid

-from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
 from fastchat.serve.model_worker import (
     BaseModelWorker,
     logger,
@@ -75,9 +74,6 @@ async def generate_stream(self, params):
         if self.tokenizer.eos_token_id is not None:
             stop_token_ids.append(self.tokenizer.eos_token_id)
         echo = params.get("echo", True)
-        use_beam_search = params.get("use_beam_search", False)
-        best_of = params.get("best_of", None)
-        n = params.get("n", 1)

         # Handle stop_str
         stop = set()
@@ -94,51 +90,27 @@ async def generate_stream(self, params):
         top_p = max(top_p, 1e-5)
         if temperature <= 1e-5:
             top_p = 1.0
-        try:
-            sampling_params = SamplingParams(
-                n=n,
-                temperature=temperature,
-                top_p=top_p,
-                use_beam_search=use_beam_search,
-                stop=list(stop),
-                max_tokens=max_new_tokens,
-                best_of=best_of,
-            )
-
-            results_generator = engine.generate(context, sampling_params, request_id)
-
-            async for request_output in results_generator:
-                prompt = request_output.prompt
-                prompt_tokens = len(request_output.prompt_token_ids)
-                output_usage = []
-                for out in request_output.outputs:
-                    completion_tokens = len(out.token_ids)
-                    total_tokens = prompt_tokens + completion_tokens
-                    output_usage.append(
-                        {
-                            "prompt_tokens": prompt_tokens,
-                            "completion_tokens": completion_tokens,
-                            "total_tokens": total_tokens,
-                        }
-                    )
-
-                if echo:
-                    text_outputs = [
-                        prompt + output.text for output in request_output.outputs
-                    ]
-                else:
-                    text_outputs = [output.text for output in request_output.outputs]
-
-                if sampling_params.best_of is None:
-                    text_outputs = [" ".join(text_outputs)]
-                ret = {"text": text_outputs, "error_code": 0, "usage": output_usage}
-                yield (json.dumps(ret) + "\0").encode()
-        except (ValueError, RuntimeError) as e:
-            ret = {
-                "text": f"{e}",
-                "error_code": ErrorCode.PARAM_OUT_OF_RANGE,
-                "usage": {},
-            }
+        sampling_params = SamplingParams(
+            n=1,
+            temperature=temperature,
+            top_p=top_p,
+            use_beam_search=False,
+            stop=list(stop),
+            max_tokens=max_new_tokens,
+        )
+        results_generator = engine.generate(context, sampling_params, request_id)
+
+        async for request_output in results_generator:
+            prompt = request_output.prompt
+            if echo:
+                text_outputs = [
+                    prompt + output.text for output in request_output.outputs
+                ]
+            else:
+                text_outputs = [output.text for output in request_output.outputs]
+            text_outputs = " ".join(text_outputs)
+            # Note: usage is not supported yet
+            ret = {"text": text_outputs, "error_code": 0, "usage": {}}
             yield (json.dumps(ret) + "\0").encode()

     async def generate(self, params):

async def generate(self, params):
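After the revert, the vLLM worker always builds a single SamplingParams with n=1 and no beam search, joins the output into one string, and reports an empty usage dict. The sketch below shows how a consumer of the worker's streaming output could parse the null-byte-delimited JSON chunks that generate_stream yields; it is a minimal, self-contained assumption-laden example, and the raw_stream bytes are fabricated for illustration.

import json

raw_stream = (
    b'{"text": "Once upon", "error_code": 0, "usage": {}}\x00'
    b'{"text": "Once upon a time", "error_code": 0, "usage": {}}\x00'
)

for chunk in raw_stream.split(b"\x00"):
    if not chunk:
        continue  # skip the trailing empty piece after the last delimiter
    data = json.loads(chunk.decode())
    if data["error_code"] != 0:
        raise RuntimeError(data["text"])
    print(data["text"])  # cumulative text so far; usage is not reported yet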
