Skip to content

Commit

Permalink
update server api to work with chatbotui (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
nuance1979 committed Apr 30, 2023
1 parent 9f184f1 commit 7bc9a2d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 11 deletions.
29 changes: 19 additions & 10 deletions llama_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class Conversation(BaseModel):
class Choice(BaseModel):
    # One choice of an OpenAI-style completion response.
    # `message` carries the full reply in non-streaming mode (see chat_nonstream);
    # `delta` carries one incremental token in streaming mode (see chat_stream).
    message: Optional[Message] = None
    delta: Optional[Message] = None
    # Set to "stop" on the final streamed chunk; None on all other chunks.
    finish_reason: Optional[str] = None


class Completion(BaseModel):
Expand All @@ -67,22 +68,27 @@ class ModelList(BaseModel):
app = FastAPI()


def _chat(user_utt: str, temperature: float) -> Generator[str, None, None]:
    """Stream raw generated text pieces from the model for one user utterance."""
    generation_kwargs = {
        "n_predict": 256,
        "repeat_penalty": 1.0,
        "n_threads": 8,
        "temp": temperature,
    }
    return model.generate(user_utt, **generation_kwargs)


def chat_stream(
    user_utt: str, temperature: float
) -> Generator[Dict[str, Any], None, None]:
    """Yield server-sent events: one delta chunk per generated piece of text,
    followed by a terminal chunk whose choice has finish_reason="stop"."""
    for piece in _chat(user_utt, temperature):
        logger.debug("text: %s", piece)
        delta_message = Message(role="assistant", content=piece)
        chunk = Completion(choices=[Choice(delta=delta_message)])
        yield {"event": "event", "data": chunk.json()}
    # Final chunk signals end-of-stream to OpenAI-compatible clients.
    final_chunk = Completion(choices=[Choice(finish_reason="stop")])
    yield {"event": "event", "data": final_chunk.json()}


def chat_nonstream(user_utt: str) -> Completion:
assistant_utt = "".join(_chat(user_utt))
def chat_nonstream(user_utt: str, temperature: float) -> Completion:
assistant_utt = "".join(_chat(user_utt, temperature))
logger.info("assistant: %s", assistant_utt)
return Completion(
choices=[Choice(message=Message(role="assistant", content=assistant_utt))]
Expand All @@ -92,11 +98,14 @@ def chat_nonstream(user_utt: str) -> Completion:
@app.post("/v1/chat/completions")
def chat(conv: Conversation):
    """OpenAI-compatible chat completion endpoint.

    Uses the last message of the conversation as the prompt; returns either a
    plain Completion or an SSE stream depending on `conv.stream`.
    """
    user_utt = conv.messages[-1].content
    temperature = conv.temperature
    logger.info("user: %s temperature: %f", user_utt, temperature)
    if conv.stream:
        return EventSourceResponse(
            chat_stream(user_utt, temperature), ping_message_factory=None
        )
    return chat_nonstream(user_utt, temperature)


@app.get("/v1/models")
Expand Down
6 changes: 5 additions & 1 deletion test/llama_server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,12 @@ def testPostChatStreaming(self):
from json import loads

datalines = [line for line in response.iter_lines() if line.startswith("data")]
self.assertEqual(len(MockModel.tokens) + 1, len(datalines))
for line, tok in zip(datalines, MockModel.tokens):
json = loads(line[6:])
json = loads(line[6:]) # skip prefix `data: `
self.assertEqual(1, len(json["choices"]))
self.assertEqual("assistant", json["choices"][0]["delta"]["role"])
self.assertEqual(tok, json["choices"][0]["delta"]["content"])
json = loads(datalines[-1][6:])
self.assertEqual(1, len(json["choices"]))
self.assertEqual("stop", json["choices"][0]["finish_reason"])

0 comments on commit 7bc9a2d

Please sign in to comment.