This repository has been archived by the owner on Apr 18, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathmain.py
100 lines (76 loc) · 2.74 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Start-up banner: an ASCII-art logo printed to stdout as soon as this module
# is imported (including when a server process imports it to get `app`).
print(
r"""
__ __ _ _____ _ _
\ \ / / | | / ____| | | |
\ \ /\ / /__| |__ | | | |__ __ _| |_
\ \/ \/ / _ \ '_ \| | | '_ \ / _` | __|
\ /\ / __/ |_) | |____| | | | (_| | |_
___\/__\/ \___|_.__/ \_____|_| |_|\__,_|\__|
| __ \ \ / / |/ /\ \ / / | | (_)
| |__) \ \ /\ / /| ' / \ \ / /__| |_ _ ___
| _ / \ \/ \/ / | < \ \/ / __| __| |/ __|
| | \ \ \ /\ / | . \ \ /\__ \ |_| | (__
|_| \_\ \/ \/ |_|\_\ \/ |___/\__|_|\___|
"""
)
# Progress notice printed before the imports that follow. Presumably this is
# because importing the local `model` module is slow (e.g. it may load model
# weights at import time) -- NOTE(review): confirm against model.py.
print("Importing modules...")
import asyncio
import json

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles

import model
# The ASGI application: the /ws WebSocket route and the static-file mount
# below are both registered on this instance.
app = FastAPI()
@app.websocket("/ws")
async def websocket(ws: WebSocket):
    """Serve one client's JSON-RPC 2.0 chat session over a WebSocket.

    Each connection keeps its own conversation ``state``, which is passed
    into every ``model.chat`` call and updated by the completion callback.
    ``model.chat`` runs in the default thread-pool executor so the event
    loop stays free to forward streamed tokens while generation runs.

    Supported methods:
        ``chat`` -- ``params`` must be an object with a ``text`` field.
        Streamed tokens are sent as ``{"token": ...}`` results tagged with
        the request's ``id``.

    Malformed requests get an error response and the connection stays open.
    The handler returns quietly when the client disconnects.
    """
    loop = asyncio.get_running_loop()
    await ws.accept()
    # Conversation state for this connection; written by the on_done callback.
    session = {"state": None}

    async def reply(msg_id, *, result=None, error=None):
        """Send a JSON-RPC response carrying exactly one of result/error."""
        # Explicit raise rather than assert so the check survives `python -O`.
        if (result is None) == (error is None):
            raise ValueError("Either result or error must be set!")
        if result is not None:
            await ws.send_json({"jsonrpc": "2.0", "result": result, "id": msg_id})
        else:
            await ws.send_json({"jsonrpc": "2.0", "error": error, "id": msg_id})

    def on_progress(msg_id):
        """Build the per-token callback passed to ``model.chat``."""

        def callback(token):
            # Presumably invoked from inside model.chat, which runs on an
            # executor thread; hand the send back to the event loop instead
            # of awaiting it from this thread.
            asyncio.run_coroutine_threadsafe(
                reply(msg_id, result={"token": token}), loop
            )

        return callback

    def on_done(prompt):
        """Build the completion callback: log the exchange, keep the new state."""

        def callback(result):
            print("--- input ---")
            print(prompt)
            print("--- output ---")
            print(result["output"])
            print("---")
            session["state"] = result["state"]

        return callback

    async def handle(data):
        """Validate one decoded request and dispatch it; errors are replied, not raised."""
        if not isinstance(data, dict) or data.get("jsonrpc") != "2.0":
            await reply(
                data.get("id") if isinstance(data, dict) else None,
                error="invalid message",
            )
            return

        method = data.get("method")
        params = data.get("params")
        msg_id = data.get("id")

        if method != "chat":
            await reply(msg_id, error=f"invalid method '{method}'")
            return

        # A missing or non-object `params`, or a missing `text`, is rejected
        # here instead of crashing the handler or reaching model.chat.
        text = params.get("text") if isinstance(params, dict) else None
        if text is None:
            await reply(msg_id, error="text is required")
            return

        await loop.run_in_executor(
            None,
            model.chat,
            session["state"],
            text,
            on_progress(msg_id),
            on_done(text),
        )

    try:
        while True:
            try:
                data = await ws.receive_json()
            except json.JSONDecodeError:
                # Text frame that is not valid JSON: report it, keep the socket open.
                await reply(None, error="invalid message")
                continue
            await handle(data)
    except WebSocketDisconnect:
        # Client closed the connection; nothing left to do.
        return
# Serve the front-end files from the ./static directory; html=True makes
# StaticFiles answer directory requests (including "/") with index.html.
# NOTE: this catch-all mount at "/" is registered after the /ws route because
# Starlette matches routes in registration order -- keep it last.
app.mount("/", StaticFiles(directory="static", html=True), name="static")
if __name__ == "__main__":
    # Script entry point: serve `app` with uvicorn's default settings.
    # Imported lazily so that importing this module for another ASGI
    # server does not require this code path.
    from uvicorn import run

    run(app)