Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ChatRoomManager로 채팅방 관리 로직 정리 #17

Merged
merged 6 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 15 additions & 22 deletions api/chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,42 @@
from fastapi import APIRouter
from starlette.websockets import WebSocket, WebSocketDisconnect

from service.chat.chat_room_manager import chat_rooms, ChatRoom
from service.chat.message import Message
from service.chat.user_connection import UserConnection
from core.dependencies import chat_room_manager
from dto.chat_room_response import ChatRoomResponse, ListChatRoomsResponse
from service.chat.user_connection import UserConnection

router = APIRouter(prefix="/v1/chat")


@router.post("/rooms/new")
async def create_chat_room():
new_room_id = str(uuid.uuid4())
chat_rooms[new_room_id] = ChatRoom(new_room_id)
chat_room = chat_room_manager.create_chat_room(room_id=new_room_id)

# 채팅방 클렌징을 위해 일정 시간동안 입장한 사람이 없다면 채팅방 제거
async def check_and_clear_inactive_room(room_id):
await asyncio.sleep(10)
if room_id in chat_rooms and chat_rooms[room_id].count_connections() == 0:
del chat_rooms[room_id]
if chat_room_manager.count_user_in_room(room_id=room_id) == 0:
chat_room_manager.delete_chat_room(room_id=room_id)

asyncio.create_task(check_and_clear_inactive_room(new_room_id))

return ChatRoomResponse(
room_id=new_room_id,
room_name=chat_rooms[new_room_id].room_name,
room_name=chat_room.room_name,
user_count=0,
)


@router.get("/rooms")
async def get_rooms():
async def list_rooms():
return ListChatRoomsResponse(
chat_rooms=[
ChatRoomResponse(
room_id=key,
room_name=value.room_name,
user_count=value.count_connections(),
) for key, value in chat_rooms.items()
room_id=room.room_id,
room_name=room.room_name,
user_count=room.count_connections(),
) for room in chat_room_manager.list_chat_rooms()
]
)

Expand All @@ -54,17 +53,11 @@ async def connect_chat_room(websocket: WebSocket, room_id: str, username: str):
)

try:
await chat_rooms[room_id].connect(connection)
await chat_rooms[room_id].broadcast_system_message(message=f'{username}가 방에 입장했습니다.')
await chat_room_manager.connect(room_id=room_id, connection=connection)

while True:
data = await websocket.receive_text()
message = Message.parse_raw(data)

await chat_rooms[room_id].broadcast(message)
connection.add_message(message)
message = await connection.receive_message()
await chat_room_manager.broadcast(room_id=room_id, message=message)

except WebSocketDisconnect:
await chat_rooms[room_id].disconnect(connection)
if room_id in chat_rooms:
await chat_rooms[room_id].broadcast(Message(username="System", message=f"{username}가 방에서 나갔습니다."))
await chat_room_manager.disconnect(room_id=room_id, connection=connection)
Empty file added config/__init__.py
Empty file.
2 changes: 2 additions & 0 deletions config/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# 현재 메모리에 유저당 메시지를 저장하고 있는데, 너무 많은 메모리를 차지하지 않도록 개수를 제한합니다.
MAX_MESSAGE_TO_SAVE = 30
4 changes: 2 additions & 2 deletions core/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from service.chat.connection_manager import ConnectionManager
from service.chat.chat_room_manager import ChatRoomManager
from service.emotion_analysis.emotion_classifier import EmotionClassifier

# DI 구조를 고민하다가 지금은 단순하게 여기에 의존성을 Singleton으로 선언해둡니다.

connection_manager = ConnectionManager()
chat_room_manager = ChatRoomManager()

emotion_classifier = EmotionClassifier()
# m1 import 이슈로 작업할 때는 MockEmotionClassifier 사용
Expand Down
21 changes: 8 additions & 13 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from starlette.templating import Jinja2Templates

from api import chat_api
from service.chat.chat_room_manager import chat_rooms
from core.dependencies import emotion_classifier, chat_room_manager
from service.chat.user_connection import UserConnection
from core.dependencies import emotion_classifier

app = FastAPI()
templates = Jinja2Templates(directory="templates")
Expand Down Expand Up @@ -38,18 +37,14 @@ async def broadcast_emotion_message():
try:
await asyncio.sleep(20)

for key in chat_rooms.keys():
manager = chat_rooms[key].manager

connections = manager.get_connections()
if not connections:
for room in chat_room_manager.list_chat_rooms():
if room.count_connections() == 0:
continue

chose_connection: UserConnection = random.choice(connections)

messages = chose_connection.get_messages()
user_connection: UserConnection = random.choice(room.list_connections())
messages = user_connection.list_messages()
if len(messages) == 0:
await manager.broadcast_system_message(
await room.broadcast_system_message(
message="메시지를 입력해보세요!"
)
continue
Expand All @@ -61,8 +56,8 @@ async def broadcast_emotion_message():

emotion_text = emotion_classifier.classify(combined_message)

await manager.broadcast_system_message(
message=f"{chose_connection.username}의 {emotion_text} 느껴집니다."
await room.broadcast_system_message(
message=f"{user_connection.username}의 {emotion_text} 느껴집니다."
)

except Exception as e:
Expand Down
85 changes: 79 additions & 6 deletions service/chat/chat_room_manager.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import uuid
from typing import Dict
from typing import Dict, List

from service.chat.connection_manager import ConnectionManager
from service.chat.message import Message
from service.chat.user_connection import UserConnection


# TODO: ChatRoomRepository 또는 ChatRoomManager를 만들어서 관리하도록
class NotFoundChatRoomException(Exception):
def __init__(self, room_id):
self.message = f"Chat room {room_id} not found"


class ChatRoom:
def __init__(self, room_id):
self.room_id = room_id
Expand All @@ -19,9 +23,6 @@ async def connect(self, connection: UserConnection):
async def disconnect(self, connection: UserConnection):
self.manager.disconnect(connection)

if self.manager.count_connections() == 0:
del chat_rooms[self.room_id] # 사용자가 모두 나가면 방 삭제

async def broadcast(self, message: Message):
await self.manager.broadcast(message)

Expand All @@ -31,5 +32,77 @@ async def broadcast_system_message(self, message: str):
def count_connections(self):
return self.manager.count_connections()

def list_connections(self):
return self.manager.get_connections()


class ChatRoomManager:
def __init__(self):
self.chat_room_by_id: Dict[str, ChatRoom] = {}

def create_chat_room(self, room_id: str):
chat_room = ChatRoom(room_id)
self.chat_room_by_id[room_id] = chat_room
return chat_room

def get_chat_room(self, room_id: str) -> ChatRoom | None:
return self.chat_room_by_id.get(room_id)

def list_chat_rooms(self) -> List[ChatRoom]:
return list(self.chat_room_by_id.values())

def delete_chat_room(self, room_id: str):
del self.chat_room_by_id[room_id]

async def connect(self, room_id: str, connection: UserConnection):
chat_room = self.get_chat_room(room_id)
if chat_room:
await chat_room.connect(connection)
await self.broadcast_system_message(
room_id=room_id,
message=f'{connection.username}가 방에 입장했습니다.',
)
else:
raise NotFoundChatRoomException(room_id)

async def disconnect(self, room_id: str, connection: UserConnection):
chat_room = self.get_chat_room(room_id)
if chat_room:
await chat_room.disconnect(connection)
await self.broadcast_system_message(
room_id=room_id,
message=f"{connection.username}가 방에서 나갔습니다.",
)

if chat_room.count_connections() == 0:
self.delete_chat_room(room_id)
else:
raise NotFoundChatRoomException(room_id)

async def broadcast(self, room_id: str, message: Message):
chat_room = self.get_chat_room(room_id)
if chat_room:
await chat_room.broadcast(message)
else:
raise NotFoundChatRoomException(room_id)

async def broadcast_system_message(self, room_id: str, message: str):
chat_room = self.get_chat_room(room_id)
if chat_room:
await chat_room.broadcast_system_message(message)
else:
raise NotFoundChatRoomException(room_id)

def count_user_in_room(self, room_id: str):
chat_room = self.get_chat_room(room_id)
if chat_room:
return chat_room.count_connections()
else:
raise NotFoundChatRoomException(room_id)

chat_rooms: Dict[str, ChatRoom] = {}
def list_users_in_room(self, room_id: str):
chat_room = self.get_chat_room(room_id)
if chat_room:
return chat_room.manager.get_connections()
else:
raise NotFoundChatRoomException(room_id)
16 changes: 8 additions & 8 deletions service/chat/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@ def disconnect(self, connection: UserConnection):
self.active_connections.remove(connection)

async def broadcast_system_message(self, message: str):
system_message = Message(
username="System",
message=message,
message_type=MessageType.SYSTEM_MESSAGE,
)

for connection in self.active_connections:
await connection.send_text(system_message.json())
await connection.send_message(
Message(
username="System",
message=message,
message_type=MessageType.SYSTEM_MESSAGE,
),
)

async def broadcast(self, message: Message):
for connection in self.active_connections:
await connection.send_text(message.json())
await connection.send_message(message)

def get_connections(self):
return self.active_connections
Expand Down
24 changes: 18 additions & 6 deletions service/chat/user_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from starlette.websockets import WebSocket

from config.config import MAX_MESSAGE_TO_SAVE
from service.chat.message import Message


Expand All @@ -12,14 +13,25 @@ def __init__(self, user_id: str, username: str, websocket: WebSocket):
self.websocket: WebSocket = websocket
self.messages: List[Message] = []

def accept(self):
return self.websocket.accept()
async def accept(self):
return await self.websocket.accept()

async def send_text(self, message: str):
return await self.websocket.send_text(message)
async def close(self):
return await self.websocket.close()

def add_message(self, message: Message):
async def send_message(self, message: Message):
await self.websocket.send_text(message.json())
self.save_message(message)

def save_message(self, message: Message):
self.messages.append(message)
# 메모리 사용량을 줄이기 위해 지정된 메시지 수만 메모리에 저장
if len(self.messages) > MAX_MESSAGE_TO_SAVE:
self.messages.pop(0)

def get_messages(self):
def list_messages(self) -> List[Message]:
return self.messages

async def receive_message(self) -> Message:
message = await self.websocket.receive_text()
return Message.parse_raw(message)
Loading