Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Save chats into file-based DB #66

Open
wants to merge 20 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions backend/chatsky_ui/api/api_v1/endpoints/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
from typing import Any, Dict, List, Optional, Union

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, WebSocket, WebSocketException, status
from omegaconf import OmegaConf

from chatsky_ui.api import deps
from chatsky_ui.schemas.pagination import Pagination
from chatsky_ui.schemas.preset import Preset
from chatsky_ui.services.index import Index
from chatsky_ui.services.process_manager import BuildManager, ProcessManager, RunManager
from chatsky_ui.services.websocket_manager import WebSocketManager
from chatsky_ui.schemas.process_status import Status
from chatsky_ui.core.config import settings
from chatsky_ui.db.base import read_conf

router = APIRouter()

Expand Down Expand Up @@ -166,7 +170,12 @@ async def start_run(


@router.get("/run/stop/{run_id}", status_code=200)
async def stop_run(*, run_id: int, run_manager: RunManager = Depends(deps.get_run_manager)) -> Dict[str, str]:
async def stop_run(
*,
run_id: int,
run_manager: RunManager = Depends(deps.get_run_manager),
websocket_manager: WebSocketManager = Depends(deps.get_websocket_manager)
) -> Dict[str, str]:
"""Stops a `run` process with the given id.

Args:
Expand All @@ -179,7 +188,9 @@ async def stop_run(*, run_id: int, run_manager: RunManager = Depends(deps.get_ru
Returns:
{"status": "ok"}: in case of stopping a process successfully.
"""

if websocket_manager.is_connected(run_id):
run_manager.logger.info("Closing websocket connection")
await websocket_manager.close(run_id)
return await _stop_process(run_id, run_manager, process="run")


Expand Down Expand Up @@ -260,8 +271,11 @@ async def connect(
if run_id not in run_manager.processes:
run_manager.logger.error("process with run_id '%s' exited or never existed", run_id)
raise WebSocketException(code=status.WS_1014_BAD_GATEWAY)
if await run_manager.get_status(run_id) != Status.ALIVE:
run_manager.logger.error("process with run_id '%s' isn't Alive.", run_id)
raise WebSocketException(code=status.WS_1014_BAD_GATEWAY)

await websocket_manager.connect(websocket)
await websocket_manager.connect(run_id, websocket)
run_manager.logger.info("Websocket for run process '%s' has been opened", run_id)

await websocket.send_text("Start chatting")
Expand All @@ -278,4 +292,10 @@ async def connect(
[output_task, input_task],
return_when=asyncio.FIRST_COMPLETED,
)
websocket_manager.disconnect(websocket)


@router.get("/run/chats/read", response_model=dict[str, Union[str, list[dict]]], status_code=200)
async def read_chats():
omega_chats = await read_conf(settings.chats_path)
dict_chats = OmegaConf.to_container(omega_chats, resolve=True)
return {"status": "ok", "data": dict_chats} # type: ignore
1 change: 1 addition & 0 deletions backend/chatsky_ui/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def set_config(self, **kwargs):
def _set_user_proj_paths(self):
self.builds_path = self.work_directory / "chatsky_ui/app_data/builds.yaml"
self.runs_path = self.work_directory / "chatsky_ui/app_data/runs.yaml"
self.chats_path = self.work_directory / "chatsky_ui/app_data/chats.yaml"
self.frontend_flows_path = self.work_directory / "chatsky_ui/app_data/frontend_flows.yaml"
self.dir_logs = self.work_directory / "chatsky_ui/logs"
self.presets = self.work_directory / "chatsky_ui/presets"
Expand Down
7 changes: 7 additions & 0 deletions backend/chatsky_ui/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ async def read_conf(path: Path) -> Union[DictConfig, ListConfig]:
return omega_data


async def read_conf_as_obj(path: Path) -> Union[dict, list]:
"""Returns the configurations read as python objects."""
omega_data = await read_conf(path)
conf_dict = OmegaConf.to_container(omega_data, resolve=True)
return conf_dict # type: ignore


async def write_conf(data: Union[DictConfig, ListConfig, dict, list], path: Path) -> None:
yaml_conf = OmegaConf.to_yaml(data)
async with file_lock:
Expand Down
78 changes: 59 additions & 19 deletions backend/chatsky_ui/services/websocket_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,24 @@
"""
import asyncio
from asyncio.tasks import Task
from typing import Dict, List, Set
from typing import Dict, Set
from uuid import uuid4

from fastapi import WebSocket, WebSocketDisconnect
from datetime import datetime

from chatsky_ui.core.logger_config import get_logger
from chatsky_ui.services.process_manager import ProcessManager
from chatsky_ui.db.base import write_conf, read_conf_as_obj
from chatsky_ui.core.config import settings


class WebSocketManager:
"""Controls websocket operations connect, disconnect, check status, and communicate."""

def __init__(self):
self.pending_tasks: Dict[WebSocket, Set[Task]] = dict()
self.active_connections: List[WebSocket] = []
self.active_connections: Dict[int, dict] = {}
self._logger = None

@property
Expand All @@ -28,24 +32,45 @@ def logger(self):
def set_logger(self):
self._logger = get_logger(__name__)

async def connect(self, websocket: WebSocket):
async def connect(self, run_id: int, websocket: WebSocket):
"""Accepts the websocket connection and marks it as active connection."""
await websocket.accept()
self.active_connections.append(websocket)

def disconnect(self, websocket: WebSocket):
"""Cancels pending tasks of the open websocket process and removes it from active connections."""
# TODO: await websocket.close()
ws_id = uuid4().hex
self.active_connections[run_id] = {
"websocket": websocket,
"chat": {"id": ws_id, "timestamp": datetime.now().strftime("%Y-%m-%dT%H:%M:%S"), "messages": []},
}

async def close(self, run_id: int):
"""Closes an active websocket connection."""
websocket = self.active_connections[run_id]["websocket"]
await websocket.close()

async def disconnect(
self, run_id: int, websocket: WebSocket
): # no need to pass websocket. use active_connections[run_id]
"""Executes cleanup.

- Writes the chat info to DB.
- Cancels pending tasks of the open websocket process.
- Removes the websocket from active connections."""
dict_chats = await read_conf_as_obj(settings.chats_path)
dict_chats = dict_chats or []
dict_chats.append(self.active_connections[run_id]["chat"]) # type: ignore
await write_conf(dict_chats, settings.chats_path)
self.logger.info("Chats info were written to DB")

if websocket in self.pending_tasks:
self.logger.info("Cancelling pending tasks")
for task in self.pending_tasks[websocket]:
task.cancel()
del self.pending_tasks[websocket]
self.active_connections.remove(websocket)
del self.active_connections[run_id]

def check_status(self, websocket: WebSocket):
if websocket in self.active_connections:
return websocket # return Status!
def is_connected(self, run_id: int):
"""Returns True if the run_id is connected to a websocket, False otherwise."""
return run_id in self.active_connections

async def send_process_output_to_websocket(
self, run_id: int, process_manager: ProcessManager, websocket: WebSocket
Expand All @@ -56,11 +81,19 @@ async def send_process_output_to_websocket(
response = await process_manager.processes[run_id].read_stdout()
if not response:
break
await websocket.send_text(response.decode().strip())
text = response.decode().strip()
await websocket.send_text(text)
self.active_connections[run_id]["chat"]["messages"].append(text)
except WebSocketDisconnect:
self.logger.info("Websocket connection is closed by client")
except RuntimeError:
raise
self.logger.info("Websocket connection is closed")
await self.disconnect(run_id, websocket)
except RuntimeError as e:
if "Unexpected ASGI message 'websocket.send'" in str(
e
) or "Cannot call 'send' once a close message has been sent" in str(e):
self.logger.info("Websocket connection was forced to close.")
else:
raise e

async def forward_websocket_messages_to_process(
self, run_id: int, process_manager: ProcessManager, websocket: WebSocket
Expand All @@ -72,9 +105,16 @@ async def forward_websocket_messages_to_process(
if not user_message:
break
await process_manager.processes[run_id].write_stdin(user_message.encode() + b"\n")
self.active_connections[run_id]["chat"]["messages"].append(user_message)
except asyncio.CancelledError:
self.logger.info("Websocket connection is closed")
self.logger.info("Websocket connection is cancelled")
except WebSocketDisconnect:
self.logger.info("Websocket connection is closed by client")
except RuntimeError:
raise
self.logger.info("Websocket connection is closed")
await self.disconnect(run_id, websocket)
except RuntimeError as e:
if "Unexpected ASGI message 'websocket.send'" in str(
e
) or "Cannot call 'send' once a close message has been sent" in str(e):
self.logger.info("Websocket connection was forced to close.")
else:
raise e
10 changes: 5 additions & 5 deletions backend/chatsky_ui/tests/api/test_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,15 +156,15 @@ async def test_get_run_logs(mocker, pagination):
async def test_connect(mocker):
websocket = mocker.AsyncMock()
websocket_manager = mocker.AsyncMock()
websocket_manager.disconnect = mocker.MagicMock()
run_manager = mocker.MagicMock()
run_process = mocker.MagicMock()
websocket_manager.disconnect = mocker.AsyncMock()
run_manager = mocker.AsyncMock()
run_process = mocker.AsyncMock()
run_manager.processes = {RUN_ID: run_process}
run_manager.get_status = mocker.AsyncMock(return_value=Status.ALIVE)
mocker.patch.object(websocket, "query_params", {"run_id": str(RUN_ID)})

await connect(websocket, websocket_manager, run_manager)

websocket_manager.connect.assert_awaited_once_with(websocket)
websocket_manager.connect.assert_awaited_once_with(RUN_ID, websocket)
websocket_manager.send_process_output_to_websocket.assert_awaited_once_with(RUN_ID, run_manager, websocket)
websocket_manager.forward_websocket_messages_to_process.assert_awaited_once_with(RUN_ID, run_manager, websocket)
websocket_manager.disconnect.assert_called_once_with(websocket)
32 changes: 20 additions & 12 deletions backend/chatsky_ui/tests/services/test_websocket_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,56 +2,64 @@
from fastapi import WebSocket


RUN_ID = 42


class TestWebSocketManager:
@pytest.mark.asyncio
async def test_connect(self, mocker, websocket_manager):
mocked_websocket = mocker.MagicMock(spec=WebSocket)
mocked_websocket = mocker.MagicMock()
mocked_websocket.accept = mocker.AsyncMock()

await websocket_manager.connect(mocked_websocket)
await websocket_manager.connect(RUN_ID, mocked_websocket)

mocked_websocket.accept.assert_awaited_once_with()
assert mocked_websocket in websocket_manager.active_connections
assert mocked_websocket == websocket_manager.active_connections[RUN_ID]["websocket"]

@pytest.mark.asyncio
async def test_disconnect(self, mocker, websocket_manager):
mocked_websocket = mocker.MagicMock(spec=WebSocket)
websocket_manager.active_connections.append(mocked_websocket)
websocket_manager.active_connections[RUN_ID] = {"websocket": mocked_websocket, "chat": {}}
websocket_manager.pending_tasks[mocked_websocket] = set()

websocket_manager.disconnect(mocked_websocket)
await websocket_manager.disconnect(RUN_ID, mocked_websocket)

assert mocked_websocket not in websocket_manager.pending_tasks
assert mocked_websocket not in websocket_manager.active_connections
assert RUN_ID not in websocket_manager.active_connections

@pytest.mark.asyncio
async def test_send_process_output_to_websocket(self, mocker, websocket_manager):
run_id = 42
awaited_response = "Hello from DF-Designer"

mocked_websocket = mocker.MagicMock(spec=WebSocket)
websocket_manager.active_connections[RUN_ID] = {"websocket": mocked_websocket, "chat": {"messages": []}}

websocket = mocker.AsyncMock()
run_manager = mocker.MagicMock()
run_process = mocker.MagicMock()
run_process.read_stdout = mocker.AsyncMock(side_effect=[awaited_response.encode(), None])
run_manager.processes = {run_id: run_process}
run_manager.processes = {RUN_ID: run_process}

await websocket_manager.send_process_output_to_websocket(run_id, run_manager, websocket)
await websocket_manager.send_process_output_to_websocket(RUN_ID, run_manager, websocket)

assert run_process.read_stdout.call_count == 2
websocket.send_text.assert_awaited_once_with(awaited_response)

@pytest.mark.asyncio
async def test_forward_websocket_messages_to_process(self, mocker, websocket_manager):
run_id = 42
awaited_message = "Hello from DF-Designer"

mocked_websocket = mocker.MagicMock(spec=WebSocket)
websocket_manager.active_connections[RUN_ID] = {"websocket": mocked_websocket, "chat": {"messages": []}}

websocket = mocker.AsyncMock()
websocket.receive_text = mocker.AsyncMock(side_effect=[awaited_message, None])
run_manager = mocker.MagicMock()
run_process = mocker.MagicMock()
run_process.write_stdin = mocker.AsyncMock()
run_manager.processes = {run_id: run_process}
run_manager.processes = {RUN_ID: run_process}

await websocket_manager.forward_websocket_messages_to_process(run_id, run_manager, websocket)
await websocket_manager.forward_websocket_messages_to_process(RUN_ID, run_manager, websocket)

assert websocket.receive_text.await_count == 2
run_process.write_stdin.assert_called_once_with(awaited_message.encode() + b"\n")
Loading