diff --git a/api/chat_api.py b/api/chat_api.py index a5f848a..5064beb 100644 --- a/api/chat_api.py +++ b/api/chat_api.py @@ -4,10 +4,9 @@ from fastapi import APIRouter from starlette.websockets import WebSocket, WebSocketDisconnect -from service.chat.chat_room_manager import chat_rooms, ChatRoom -from service.chat.message import Message -from service.chat.user_connection import UserConnection +from core.dependencies import chat_room_manager from dto.chat_room_response import ChatRoomResponse, ListChatRoomsResponse +from service.chat.user_connection import UserConnection router = APIRouter(prefix="/v1/chat") @@ -15,32 +14,32 @@ @router.post("/rooms/new") async def create_chat_room(): new_room_id = str(uuid.uuid4()) - chat_rooms[new_room_id] = ChatRoom(new_room_id) + chat_room = chat_room_manager.create_chat_room(room_id=new_room_id) # 채팅방 클렌징을 위해 일정 시간동안 입장한 사람이 없다면 채팅방 제거 async def check_and_clear_inactive_room(room_id): await asyncio.sleep(10) - if room_id in chat_rooms and chat_rooms[room_id].count_connections() == 0: - del chat_rooms[room_id] + if chat_room_manager.count_user_in_room(room_id=room_id) == 0: + chat_room_manager.delete_chat_room(room_id=room_id) asyncio.create_task(check_and_clear_inactive_room(new_room_id)) return ChatRoomResponse( room_id=new_room_id, - room_name=chat_rooms[new_room_id].room_name, + room_name=chat_room.room_name, user_count=0, ) @router.get("/rooms") -async def get_rooms(): +async def list_rooms(): return ListChatRoomsResponse( chat_rooms=[ ChatRoomResponse( - room_id=key, - room_name=value.room_name, - user_count=value.count_connections(), - ) for key, value in chat_rooms.items() + room_id=room.room_id, + room_name=room.room_name, + user_count=room.count_connections(), + ) for room in chat_room_manager.list_chat_rooms() ] ) @@ -54,17 +53,11 @@ async def connect_chat_room(websocket: WebSocket, room_id: str, username: str): ) try: - await chat_rooms[room_id].connect(connection) - await chat_rooms[room_id].broadcast_system_message(message=f'{username}가 방에 입장했습니다.') + await chat_room_manager.connect(room_id=room_id, connection=connection) while True: - data = await websocket.receive_text() - message = Message.parse_raw(data) - - await chat_rooms[room_id].broadcast(message) - connection.add_message(message) + message = await connection.receive_message() + await chat_room_manager.broadcast(room_id=room_id, message=message) except WebSocketDisconnect: - await chat_rooms[room_id].disconnect(connection) - if room_id in chat_rooms: - await chat_rooms[room_id].broadcast(Message(username="System", message=f"{username}가 방에서 나갔습니다.")) + await chat_room_manager.disconnect(room_id=room_id, connection=connection) diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/config/config.py b/config/config.py new file mode 100644 index 0000000..f20b825 --- /dev/null +++ b/config/config.py @@ -0,0 +1,2 @@ +# 현재 메모리에 유저당 메시지를 저장하고 있는데, 너무 많은 메모리를 차지하지 않도록 개수를 제한합니다. +MAX_MESSAGE_TO_SAVE = 30 diff --git a/core/dependencies.py b/core/dependencies.py index 21ccbe3..23bf4db 100644 --- a/core/dependencies.py +++ b/core/dependencies.py @@ -1,9 +1,9 @@ -from service.chat.connection_manager import ConnectionManager +from service.chat.chat_room_manager import ChatRoomManager from service.emotion_analysis.emotion_classifier import EmotionClassifier # DI 구조를 고민하다가 지금은 단순하게 여기에 의존성을 Singleton으로 선언해둡니다. -connection_manager = ConnectionManager() +chat_room_manager = ChatRoomManager() emotion_classifier = EmotionClassifier() # m1 import 이슈로 작업할 때는 MockEmotionClassifier 사용 diff --git a/main.py b/main.py index bd0666d..f8de543 100644 --- a/main.py +++ b/main.py @@ -7,9 +7,8 @@ from starlette.templating import Jinja2Templates from api import chat_api -from service.chat.chat_room_manager import chat_rooms +from core.dependencies import emotion_classifier, chat_room_manager from service.chat.user_connection import UserConnection -from core.dependencies import emotion_classifier app = FastAPI() templates = Jinja2Templates(directory="templates") @@ -38,18 +37,14 @@ async def broadcast_emotion_message(): try: await asyncio.sleep(20) - for key in chat_rooms.keys(): - manager = chat_rooms[key].manager - - connections = manager.get_connections() - if not connections: + for room in chat_room_manager.list_chat_rooms(): + if room.count_connections() == 0: continue - chose_connection: UserConnection = random.choice(connections) - - messages = chose_connection.get_messages() + user_connection: UserConnection = random.choice(room.list_connections()) + messages = user_connection.list_messages() if len(messages) == 0: - await manager.broadcast_system_message( + await room.broadcast_system_message( message="메시지를 입력해보세요!" ) continue @@ -61,8 +56,8 @@ async def broadcast_emotion_message(): emotion_text = emotion_classifier.classify(combined_message) - await manager.broadcast_system_message( - message=f"{chose_connection.username}의 {emotion_text} 느껴집니다." + await room.broadcast_system_message( + message=f"{user_connection.username}의 {emotion_text} 느껴집니다." ) except Exception as e: diff --git a/service/chat/chat_room_manager.py b/service/chat/chat_room_manager.py index a3c4cf2..407d40d 100644 --- a/service/chat/chat_room_manager.py +++ b/service/chat/chat_room_manager.py @@ -1,12 +1,16 @@ import uuid -from typing import Dict +from typing import Dict, List from service.chat.connection_manager import ConnectionManager from service.chat.message import Message from service.chat.user_connection import UserConnection -# TODO: ChatRoomRepository 또는 ChatRoomManager를 만들어서 관리하도록 +class NotFoundChatRoomException(Exception): + def __init__(self, room_id): + self.message = f"Chat room {room_id} not found" + + class ChatRoom: def __init__(self, room_id): self.room_id = room_id @@ -19,9 +23,6 @@ async def connect(self, connection: UserConnection): async def disconnect(self, connection: UserConnection): self.manager.disconnect(connection) - if self.manager.count_connections() == 0: - del chat_rooms[self.room_id] # 사용자가 모두 나가면 방 삭제 - async def broadcast(self, message: Message): await self.manager.broadcast(message) @@ -31,5 +32,77 @@ async def broadcast_system_message(self, message: str): def count_connections(self): return self.manager.count_connections() + def list_connections(self): + return self.manager.get_connections() + + +class ChatRoomManager: + def __init__(self): + self.chat_room_by_id: Dict[str, ChatRoom] = {} + + def create_chat_room(self, room_id: str): + chat_room = ChatRoom(room_id) + self.chat_room_by_id[room_id] = chat_room + return chat_room + + def get_chat_room(self, room_id: str) -> ChatRoom | None: + return self.chat_room_by_id.get(room_id) + + def list_chat_rooms(self) -> List[ChatRoom]: + return list(self.chat_room_by_id.values()) + + def delete_chat_room(self, room_id: str): + del self.chat_room_by_id[room_id] + + async def connect(self, room_id: str, connection: UserConnection): + chat_room = self.get_chat_room(room_id) + if chat_room: + await chat_room.connect(connection) + await self.broadcast_system_message( + room_id=room_id, + message=f'{connection.username}가 방에 입장했습니다.', + ) + else: + raise NotFoundChatRoomException(room_id) + + async def disconnect(self, room_id: str, connection: UserConnection): + chat_room = self.get_chat_room(room_id) + if chat_room: + await chat_room.disconnect(connection) + await self.broadcast_system_message( + room_id=room_id, + message=f"{connection.username}가 방에서 나갔습니다.", + ) + + if chat_room.count_connections() == 0: + self.delete_chat_room(room_id) + else: + raise NotFoundChatRoomException(room_id) + + async def broadcast(self, room_id: str, message: Message): + chat_room = self.get_chat_room(room_id) + if chat_room: + await chat_room.broadcast(message) + else: + raise NotFoundChatRoomException(room_id) + + async def broadcast_system_message(self, room_id: str, message: str): + chat_room = self.get_chat_room(room_id) + if chat_room: + await chat_room.broadcast_system_message(message) + else: + raise NotFoundChatRoomException(room_id) + + def count_user_in_room(self, room_id: str): + chat_room = self.get_chat_room(room_id) + if chat_room: + return chat_room.count_connections() + else: + raise NotFoundChatRoomException(room_id) -chat_rooms: Dict[str, ChatRoom] = {} + def list_users_in_room(self, room_id: str): + chat_room = self.get_chat_room(room_id) + if chat_room: + return chat_room.manager.get_connections() + else: + raise NotFoundChatRoomException(room_id) diff --git a/service/chat/connection_manager.py b/service/chat/connection_manager.py index 37065b2..3dd43b5 100644 --- a/service/chat/connection_manager.py +++ b/service/chat/connection_manager.py @@ -14,18 +14,18 @@ def disconnect(self, connection: UserConnection): self.active_connections.remove(connection) async def broadcast_system_message(self, message: str): - system_message = Message( - username="System", - message=message, - message_type=MessageType.SYSTEM_MESSAGE, - ) - for connection in self.active_connections: - await connection.send_text(system_message.json()) + await connection.send_message( + Message( + username="System", + message=message, + message_type=MessageType.SYSTEM_MESSAGE, + ), + ) async def broadcast(self, message: Message): for connection in self.active_connections: - await connection.send_text(message.json()) + await connection.send_message(message) def get_connections(self): return self.active_connections diff --git a/service/chat/user_connection.py b/service/chat/user_connection.py index decb3bc..5ec891e 100644 --- a/service/chat/user_connection.py +++ b/service/chat/user_connection.py @@ -2,6 +2,7 @@ from starlette.websockets import WebSocket +from config.config import MAX_MESSAGE_TO_SAVE from service.chat.message import Message @@ -12,14 +13,25 @@ def __init__(self, user_id: str, username: str, websocket: WebSocket): self.websocket: WebSocket = websocket self.messages: List[Message] = [] - def accept(self): - return self.websocket.accept() + async def accept(self): + return await self.websocket.accept() - async def send_text(self, message: str): - return await self.websocket.send_text(message) + async def close(self): + return await self.websocket.close() - def add_message(self, message: Message): + async def send_message(self, message: Message): + await self.websocket.send_text(message.json()) + self.save_message(message) + + def save_message(self, message: Message): self.messages.append(message) + # 메모리 사용량을 줄이기 위해 지정된 메시지 수만 메모리에 저장 + if len(self.messages) > MAX_MESSAGE_TO_SAVE: + self.messages.pop(0) - def get_messages(self): + def list_messages(self) -> List[Message]: return self.messages + + async def receive_message(self) -> Message: + message = await self.websocket.receive_text() + return Message.parse_raw(message)