Merge pull request #74 from zoan37/window_dev

Support Window AI

joshsny authored Jun 8, 2023
2 parents 0eb4d6d + 7d46737 commit 6cc2d3f
Showing 10 changed files with 218 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -165,6 +165,7 @@ logs/
start.sh

database.db
database.db-journal
vectors.pickle.gz
agents/*.txt
discord_server.out
4 changes: 4 additions & 0 deletions README.md
@@ -52,6 +52,10 @@ Read through the dedicated [Discord setup docs](DISCORD.md)

## Using with Anthropic Claude
Make sure you have an `ANTHROPIC_API_KEY` in your env, then you can use `poetry run world --claude` which will run the world using `claude-v1` for some calls and `claude-instant-v1` for others.

## Using with Window
Make sure you have the [Window extension](https://windowai.io/) installed, then you can use `poetry run world --window`. Some models may be slow to respond, since the prompts are very long.

## Contributing

We enthusiastically welcome contributions to GPTeam! To contribute, please follow these steps:
16 changes: 16 additions & 0 deletions poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -29,6 +29,7 @@ aiosqlite = "^0.19.0"
hyperdb-python = "^0.1.3"
quart = "^0.18.4"
anthropic = "^0.2.10"
websocket-client = "^1.5.2"

[tool.black]
line-length = 88
1 change: 1 addition & 0 deletions src/utils/model_name.py
@@ -6,3 +6,4 @@ class ChatModelName(Enum):
    GPT4 = "gpt-4"
    CLAUDE = "claude-v1"
    CLAUDE_INSTANT = "claude-instant-v1"
    WINDOW = "window"
3 changes: 3 additions & 0 deletions src/utils/models.py
@@ -5,6 +5,7 @@
from langchain.chat_models.base import BaseChatModel
from langchain.llms import OpenAI
from langchain.schema import BaseMessage
from utils.windowai_model import ChatWindowAI

from .cache import chat_json_cache, json_cache
from .model_name import ChatModelName
@@ -28,6 +29,8 @@ def get_chat_model(name: ChatModelName, **kwargs) -> BaseChatModel:
        return ChatAnthropic(model=name.value, **kwargs)
    elif name == ChatModelName.CLAUDE_INSTANT:
        return ChatAnthropic(model=name.value, **kwargs)
    elif name == ChatModelName.WINDOW:
        return ChatWindowAI(model_name=name.value, **kwargs)
    else:
        raise ValueError(f"Invalid model name: {name}")

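For orientation, a minimal sketch of how the new branch is exercised through the factory (illustrative, not part of the diff; `temperature` is just an example kwarg forwarded via `**kwargs`, and the import paths assume `src/` is on `sys.path`, as in the repo's own modules):

from utils.model_name import ChatModelName
from utils.models import get_chat_model

# Selects the Window-backed chat model; kwargs are forwarded unchanged.
chat_model = get_chat_model(ChatModelName.WINDOW, temperature=0)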
14 changes: 12 additions & 2 deletions src/utils/parameters.py
@@ -21,11 +21,21 @@
 ANNOUNCER_DISCORD_TOKEN = os.getenv("ANNOUNCER_DISCORD_TOKEN")

 DEFAULT_SMART_MODEL = (
-    ChatModelName.TURBO if "--turbo" in sys.argv else ChatModelName.CLAUDE if "--claude" in sys.argv else ChatModelName.GPT4
+    ChatModelName.TURBO
+    if "--turbo" in sys.argv
+    else ChatModelName.CLAUDE
+    if "--claude" in sys.argv
+    else ChatModelName.WINDOW
+    if "--window" in sys.argv
+    else ChatModelName.GPT4
 )

 DEFAULT_FAST_MODEL = (
-    ChatModelName.CLAUDE_INSTANT if "--claude" in sys.argv else ChatModelName.TURBO
+    ChatModelName.CLAUDE_INSTANT
+    if "--claude" in sys.argv
+    else ChatModelName.WINDOW
+    if "--window" in sys.argv
+    else ChatModelName.TURBO
 )


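Read as chained conditional expressions, the two defaults resolve per flag as follows (an illustrative trace, not part of the diff; earlier conditions win when several flags are passed):

# Flag passed     DEFAULT_SMART_MODEL    DEFAULT_FAST_MODEL
# --turbo     ->  TURBO                  TURBO
# --claude    ->  CLAUDE                 CLAUDE_INSTANT
# --window    ->  WINDOW                 WINDOW
# (none)      ->  GPT4                   TURBO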
105 changes: 105 additions & 0 deletions src/utils/windowai_model.py
@@ -0,0 +1,105 @@
import langchain
from langchain.chat_models.base import BaseChatModel, SimpleChatModel
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    ChatResult,
    HumanMessage,
    SystemMessage,
)
from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict
import websocket
import uuid
import json


class MessageDict(TypedDict):
    role: str
    content: str


class RequestDict(TypedDict):
    messages: List[MessageDict]
    temperature: float
    request_id: str


class ResponseDict(TypedDict):
    content: str
    request_id: str


class ChatWindowAI(BaseChatModel):
    model_name: str = "window"
    """Model name to use."""
    temperature: float = 0
    """What sampling temperature to use."""
    streaming: bool = False
    """Whether to stream the results."""
    request_timeout: int = 3600
    """Timeout in seconds for the request."""

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "window-chat"

    def _generate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        output_str = self._call(messages, stop=stop)
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        result = ChatResult(generations=[generation])
        return result

    async def _agenerate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        return self._generate(messages, stop=stop)

    def _call(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> str:
        request_id = str(uuid.uuid4())
        request: RequestDict = {
            "messages": [],
            "temperature": self.temperature,
            "request_id": request_id,
        }

        for message in messages:
            role = "user"  # default role is user
            if isinstance(message, HumanMessage):
                role = "user"
            elif isinstance(message, AIMessage):
                role = "assistant"
            elif isinstance(message, SystemMessage):
                role = "system"

            request["messages"].append(
                {
                    "role": role,
                    "content": message.content,
                }
            )

        ws = websocket.WebSocket()
        ws.connect("ws://127.0.0.1:5000/windowmodel")
        ws.send(json.dumps(request))
        message = ws.recv()
        ws.close()

        response: ResponseDict = json.loads(message)

        response_content = response["content"]
        response_request_id = response["request_id"]

        # sanity check that response corresponds to request
        if request_id != response_request_id:
            raise ValueError(
                f"Invalid request ID: {response_request_id}, expected: {request_id}"
            )

        return response_content
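A minimal usage sketch (illustrative, not part of the diff): it assumes the GPTeam web server is listening on port 5000 and a browser tab with the Window extension is connected, since `_call` blocks on the websocket round-trip. Note that `_agenerate` simply delegates to the synchronous `_generate`, so the round-trip blocks the event loop as well. Calling conventions follow the langchain version pinned in this repo, where invoking a chat model on a message list returns an `AIMessage`:

from langchain.schema import HumanMessage

from utils.windowai_model import ChatWindowAI

chat = ChatWindowAI(temperature=0)
# Blocks until the browser-side window.ai call returns via the relay.
reply = chat([HumanMessage(content="Say hello in one sentence.")])
print(reply.content)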
39 changes: 39 additions & 0 deletions src/web/__init__.py
@@ -5,12 +5,17 @@

from dotenv import load_dotenv
from quart import Quart, abort, make_response, send_file, websocket
from ..utils.model_name import ChatModelName
from ..utils.parameters import DEFAULT_FAST_MODEL, DEFAULT_SMART_MODEL

from src.utils.database.base import Tables
from src.utils.database.client import get_database

load_dotenv()

window_request_queue = asyncio.Queue()
window_response_queue = asyncio.Queue()


def get_server():
    app = Quart(__name__)
@@ -89,4 +94,38 @@ async def world_websocket():
                {"agents": sorted_agents, "name": worlds[0]["name"]}
            )

    @app.websocket("/window")
    async def window_websocket():
        if (
            DEFAULT_SMART_MODEL != ChatModelName.WINDOW
            and DEFAULT_FAST_MODEL != ChatModelName.WINDOW
        ):
            return

        while True:
            await asyncio.sleep(0.25)

            request = await window_request_queue.get()
            await websocket.send(request)

            response = await websocket.receive()
            await window_response_queue.put(response)

    @app.websocket("/windowmodel")
    async def window_model_websocket():
        if (
            DEFAULT_SMART_MODEL != ChatModelName.WINDOW
            and DEFAULT_FAST_MODEL != ChatModelName.WINDOW
        ):
            return

        while True:
            await asyncio.sleep(0.25)

            request = await websocket.receive()
            await window_request_queue.put(request)

            response = await window_response_queue.get()
            await websocket.send(response)

    return app
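The two handlers and the module-level queues form a relay: `ChatWindowAI` connects to `/windowmodel`, the logs page connects to `/window`, and requests and responses are shuttled between them through `window_request_queue` and `window_response_queue`, matched up afterwards by `request_id`. A hand-rolled client for exercising the relay might look like this (a sketch, assuming the server is up and a Window-enabled browser tab is attached):

import json
import uuid

import websocket  # websocket-client, the dependency added in pyproject.toml above

request = {
    "messages": [{"role": "user", "content": "ping"}],
    "temperature": 0.0,
    "request_id": str(uuid.uuid4()),
}

ws = websocket.WebSocket()
ws.connect("ws://127.0.0.1:5000/windowmodel")
ws.send(json.dumps(request))
response = json.loads(ws.recv())  # blocks until the browser replies
ws.close()

assert response["request_id"] == request["request_id"]
print(response["content"])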
36 changes: 36 additions & 0 deletions src/web/templates/logs.html
@@ -77,6 +77,42 @@
    };
  }, []);

  React.useEffect(() => {
    const socket = new WebSocket('ws://' + window.location.host + '/window');

    socket.onmessage = (e) => {
      if (!window.ai) {
        alert('window.ai not found. Please install at https://windowai.io/');
        return;
      }

      const { request_id, messages, temperature } = JSON.parse(e.data);

      window.ai.generateText(
        {
          messages: messages
        },
        {
          temperature: temperature,
          // enabling streaming prevents "timeout of 42000ms exceeded" and "status code 405" errors
          onStreamResult: (res) => {}
        }
      ).then(([response]) => {
        const result = {
          request_id: request_id,
          content: response.message.content
        };

        socket.send(JSON.stringify(result));
      });
    };

    return () => {
      socket.close();
    };
  }, []);

  return (
