From fd23d6de16d75268c81d38b1c379268824874d7b Mon Sep 17 00:00:00 2001 From: Victor Wildner Date: Sat, 4 Nov 2023 04:07:32 -0300 Subject: [PATCH] Fix: Chat use JWT validation (#87) * fix: update seed db init * feat: add token validation * feat: validate user when connecting to chat * fix: update spec * fix: remove deprecated chat gateway * feat: add connected user validation --- .devcontainer/postCreateCommand.sh | 2 +- backend/src/chat/chat.gateway.spec.ts | 64 ------------- backend/src/chat/chat.gateway.ts | 96 +++++++++++-------- backend/src/chat/chat.module.ts | 11 ++- backend/src/chat/dto/index.ts | 1 + backend/src/chat/dto/token.dto.ts | 5 + frontend/src/components/Chat/ChannelCard.tsx | 6 +- .../Chat/ChatUsersChannelPopOver.tsx | 9 +- frontend/src/components/Chat/OpenChannel.tsx | 4 +- frontend/src/contexts/ChatContext.tsx | 11 +++ frontend/src/services/chatClient.ts | 9 +- 11 files changed, 101 insertions(+), 117 deletions(-) delete mode 100644 backend/src/chat/chat.gateway.spec.ts create mode 100644 backend/src/chat/dto/token.dto.ts diff --git a/.devcontainer/postCreateCommand.sh b/.devcontainer/postCreateCommand.sh index b61a5115..7f0c8ebd 100755 --- a/.devcontainer/postCreateCommand.sh +++ b/.devcontainer/postCreateCommand.sh @@ -10,7 +10,7 @@ npm install -g @nestjs/cli npm run prisma:generate # Generate Prisma Migrations -npx prisma migrate dev --name init +npx prisma db push --force-reset && npx prisma db seed # Docker daemon setup sudo chown node:node /var/run/docker.sock diff --git a/backend/src/chat/chat.gateway.spec.ts b/backend/src/chat/chat.gateway.spec.ts deleted file mode 100644 index c778538c..00000000 --- a/backend/src/chat/chat.gateway.spec.ts +++ /dev/null @@ -1,64 +0,0 @@ -import { Test, TestingModule } from '@nestjs/testing'; -import { ChatGateway } from './chat.gateway'; -import { ChatService } from './chat.service'; -import { ChatServiceMock } from './chat.service.mock'; -import { Server } from 'socket.io'; -import { SocketIOClientMock } from './socket.io-mock-client'; -import { ChatDto } from './dto'; -describe('ChatGateway', () => { - let server; - let client; - const chatId = 1; - beforeAll(async () => { - // Start server - server = new Server(); - server.listen(3333); - - // Connect client - client = new SocketIOClientMock(chatId); - }); - - afterAll(async () => { - // Close server - server.close(); - }); - - let chatGateway: ChatGateway; - let chatService: ChatService; - - beforeEach(async () => { - const module: TestingModule = await Test.createTestingModule({ - providers: [ - ChatGateway, - { - provide: ChatService, - useClass: ChatServiceMock, - }, - ], - }).compile(); - - chatGateway = module.get(ChatGateway); - chatService = module.get(ChatService); - }); - - it('should be defined', () => { - expect(chatGateway).toBeDefined(); - }); - - it('should join a chat successfully', async () => { - const login = 'testUser'; - const chatDto: ChatDto = { - chatId: 1, - password: 'password', - }; - - const getChatByIdSpy = jest.spyOn(chatService, 'getChatById'); - const addUserToChatSpy = jest.spyOn(chatService, 'addUserToChat'); - const joinChatSpy = jest.spyOn(chatGateway, 'joinChat'); - await chatGateway.joinChat(login, chatDto, client); - - expect(getChatByIdSpy).toBeCalledWith(chatId); - expect(addUserToChatSpy).toBeCalledWith(login, chatId); - expect(joinChatSpy).toBeCalledWith(login, chatDto, client); - }); -}); diff --git a/backend/src/chat/chat.gateway.ts b/backend/src/chat/chat.gateway.ts index 697372fd..2aca8035 100644 --- a/backend/src/chat/chat.gateway.ts +++ b/backend/src/chat/chat.gateway.ts @@ -7,42 +7,36 @@ import { SubscribeMessage, WebSocketGateway, WebSocketServer, + WsException, } from '@nestjs/websockets'; import { Server, Socket } from 'socket.io'; import { ChatService } from './chat.service'; +import { ParseIntPipe } from '@nestjs/common'; import { - ExecutionContext, - ParseIntPipe, - createParamDecorator, -} from '@nestjs/common'; -import { ChatDto, ChatMessageDto, NewChatDto, InviteChatDto } from './dto'; + ChatDto, + ChatMessageDto, + NewChatDto, + InviteChatDto, + TokenPayload, +} from './dto'; import * as argon2 from 'argon2'; +import { JwtService } from '@nestjs/jwt'; +import { UsersService } from 'src/users/users.service'; interface ConnectedUsers { [key: number]: Socket; } -export const SocketUser = createParamDecorator( - (data: string, ctx: ExecutionContext) => { - // eslint-disable-line @typescript-eslint/no-unused-vars - // const client = ctx.switchToWs().getClient(); - // const user = client.handshake.auth?.user; - - // return data ? user?.[data] : user; - const user = { - login: 'caio', - }; - - return user[data]; - }, -); - @WebSocketGateway({ namespace: 'chat' }) export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect { - constructor(private chatService: ChatService) {} + constructor( + private chatService: ChatService, + private jwtService: JwtService, + private usersService: UsersService, + ) {} private connectedUsers: ConnectedUsers = {}; @WebSocketServer() @@ -60,11 +54,12 @@ export class ChatGateway @SubscribeMessage('message') async createMessage( - @SocketUser('login') login: any, @MessageBody() messageDto: ChatMessageDto, @ConnectedSocket() client: Socket, ) { const { chatId, content } = messageDto; + const login = client.handshake.auth?.user?.login; + client.emit('userLogin', client.handshake.auth?.user); const member = await this.chatService.getMemberFromChat(chatId, login); if (!member) { client.emit('error', { error: 'You are not a member of this chat' }); @@ -119,11 +114,11 @@ export class ChatGateway @SubscribeMessage('createChat') async createChat( - @SocketUser('login') login: string, @MessageBody() chatDto: NewChatDto, @ConnectedSocket() client: Socket, ) { const { chatName, chatType, password } = chatDto; + const login = client.handshake.auth?.user?.login; if (chatType === 'PUBLIC' && password) { client.emit('error', { error: 'Public chat cannot have password' }); return; @@ -157,11 +152,11 @@ export class ChatGateway @SubscribeMessage('createPrivateChat') async createPrivateChat( - @SocketUser('login') login: string, @MessageBody() privateChat: InviteChatDto, @ConnectedSocket() client: Socket, ) { const { guestList } = privateChat; + const login = client.handshake.auth?.user?.login; const createdChat = await this.chatService.createPrivateChat( login, guestList, @@ -190,7 +185,6 @@ export class ChatGateway // You already have to exist in this chat but cannot join the socket @SubscribeMessage('joinChat') async joinChat( - @SocketUser('login') login: string, @MessageBody() chatDto: ChatDto, @ConnectedSocket() client: Socket, ) { @@ -215,6 +209,7 @@ export class ChatGateway return; } } + const login = client.handshake.auth?.user?.login; const addedUser = await this.chatService.addUserToChat(login, chatId); if (!addedUser) { client.emit('joinChat', { message: 'User is already in chat', chat }); @@ -227,10 +222,10 @@ export class ChatGateway @SubscribeMessage('leaveChat') async leaveChat( - @SocketUser('login') login: string, @MessageBody('chatId', new ParseIntPipe()) chatId: number, @ConnectedSocket() client: Socket, ) { + const login = client.handshake.auth?.user?.login; const you = await this.chatService.getMemberFromChat(chatId, login); await this.chatService.removeUserFromChat(login, chatId); client.leave(`chat:${chatId}`); @@ -290,7 +285,6 @@ export class ChatGateway // WARNING: This method should not be invoked by the client @SubscribeMessage('deleteChat') async deleteChat( - @SocketUser('login') login: string, @MessageBody('chatId', new ParseIntPipe()) chatId: number, @ConnectedSocket() client: Socket, ) { @@ -300,6 +294,7 @@ export class ChatGateway client.emit('error', { error: 'Chat not found' }); return; } + const login = client.handshake.auth?.user?.login; const member = await this.chatService.getMemberFromChat(chatId, login); if (!member || member.role === 'MEMBER') { client.emit('error', { @@ -331,11 +326,11 @@ export class ChatGateway // TODO: Drop this rule and replace it by an invite event @SubscribeMessage('addToChat') async addToChat( - @SocketUser('login') login: string, @MessageBody() inviteChat: InviteChatDto, @ConnectedSocket() client: Socket, ) { const { chatId, guestList } = inviteChat; + const login = client.handshake.auth?.user?.login; const updatedChat = await this.chatService.addUsersToChat( chatId, guestList, @@ -360,11 +355,11 @@ export class ChatGateway @SubscribeMessage('giveAdmin') async giveAdmin( - @SocketUser('login') login: string, @MessageBody() users: InviteChatDto, @ConnectedSocket() client: Socket, ) { const { chatId, guestList } = users; + const login = client.handshake.auth?.user?.login; for (const user of guestList) { if (await this.notValidAction('giveAdmin', chatId, login, user, client)) { return; @@ -384,11 +379,11 @@ export class ChatGateway @SubscribeMessage('kickMember') async kickMember( - @SocketUser('login') login: string, @MessageBody('user') user: string, @MessageBody('chatId', new ParseIntPipe()) chatId: number, @ConnectedSocket() client: Socket, ) { + const login = client.handshake.auth?.user?.login; if (await this.notValidAction('kickMember', chatId, login, user, client)) { return; } @@ -434,11 +429,11 @@ export class ChatGateway @SubscribeMessage('banMember') async banMember( - @SocketUser('login') login: string, @MessageBody('chatId', new ParseIntPipe()) chatId: number, @MessageBody('user') user: string, @ConnectedSocket() client: Socket, ) { + const login = client.handshake.auth?.user?.login; if (await this.notValidAction('banMember', chatId, login, user, client)) { return; } @@ -462,11 +457,11 @@ export class ChatGateway @SubscribeMessage('muteMember') async muteMember( - @SocketUser('login') login: string, @MessageBody('chatId', new ParseIntPipe()) chatId: number, @MessageBody('user') user: string, @ConnectedSocket() client: Socket, ) { + const login = client.handshake.auth?.user?.login; if (await this.notValidAction('muteMember', chatId, login, user, client)) { return; } @@ -483,11 +478,11 @@ export class ChatGateway @SubscribeMessage('unmuteMember') async unmuteMember( - @SocketUser('login') login: string, @MessageBody('chatId', new ParseIntPipe()) chatId: number, @MessageBody('user') user: string, @ConnectedSocket() client: Socket, ) { + const login = client.handshake.auth?.user?.login; if ( await this.notValidAction('unmuteMember', chatId, login, user, client) ) { @@ -513,7 +508,9 @@ export class ChatGateway const chat = await this.chatService.verifyChatPassword(chatId, password); if (!chat) { - return client.emit('verifyPassword', { error: 'Error handling the request' }); + return client.emit('verifyPassword', { + error: 'Error handling the request', + }); } return client.emit('verifyPassword', { message: 'Password is correct' }); } @@ -536,11 +533,11 @@ export class ChatGateway } // TODO: async handleConnection(@ConnectedSocket() client: Socket) { - // const login = client.handshake.auth?.user?.login; + const login = client.handshake.auth?.user?.login; // TODO: remove this hardcoded user id - const login = 'caio'; if (!login) { client.emit('connected', { error: 'User not found' }); + client.disconnect(); return; } this.connectedUsers[login] = client; @@ -556,7 +553,30 @@ export class ChatGateway } } - afterInit(server: any) { - // ... + afterInit(_: Server) { + this.server.use((socket, next) => { + this.validateConnection(socket) + .then((user) => { + socket.handshake.auth['user'] = user; + console.log(`User ${socket.handshake.auth['user'].login} connected`); + socket.emit('userLogin', user); + next(); + }) + .catch((err) => { + return next(new Error(err)); + }); + }); + } + + private validateConnection(client: Socket) { + const token = client.handshake.headers.cookie.split(';')[0].split('=')[1]; + try { + const payload = this.jwtService.verify(token, { + secret: process.env.JWT_SECRET, + }); + return this.usersService.findOne(payload.sub); + } catch { + throw new WsException('Token invalid or expired'); + } } } diff --git a/backend/src/chat/chat.module.ts b/backend/src/chat/chat.module.ts index 8778848c..4425ade0 100644 --- a/backend/src/chat/chat.module.ts +++ b/backend/src/chat/chat.module.ts @@ -4,10 +4,17 @@ import { PrismaModule } from '../prisma/prisma.module'; import { PrismaService } from '../prisma/prisma.service'; import { ChatGateway } from './chat.gateway'; import { ChatController } from './chat.controller'; - +import { JwtService } from '@nestjs/jwt'; +import { UsersService } from 'src/users/users.service'; @Module({ imports: [PrismaModule], controllers: [ChatController], - providers: [ChatService, PrismaService, ChatGateway], + providers: [ + ChatService, + PrismaService, + ChatGateway, + JwtService, + UsersService, + ], }) export class ChatModule {} diff --git a/backend/src/chat/dto/index.ts b/backend/src/chat/dto/index.ts index 5bcf9240..f12e5794 100644 --- a/backend/src/chat/dto/index.ts +++ b/backend/src/chat/dto/index.ts @@ -1 +1,2 @@ export * from './chat.dto'; +export * from './token.dto'; diff --git a/backend/src/chat/dto/token.dto.ts b/backend/src/chat/dto/token.dto.ts new file mode 100644 index 00000000..eccb68b3 --- /dev/null +++ b/backend/src/chat/dto/token.dto.ts @@ -0,0 +1,5 @@ +export interface TokenPayload { + sub: string; // user id + mfaEnabled: boolean; + mfaAuthenticated: boolean; +} diff --git a/frontend/src/components/Chat/ChannelCard.tsx b/frontend/src/components/Chat/ChannelCard.tsx index 73cdc47d..eed8e213 100644 --- a/frontend/src/components/Chat/ChannelCard.tsx +++ b/frontend/src/components/Chat/ChannelCard.tsx @@ -8,7 +8,7 @@ type ChannelCardProps = { }; export default function ChannelCard({ chat }: ChannelCardProps) { - const { setShowElement, setSelectedChat } = useContext(ChatContext); + const { setShowElement, setSelectedChat, user } = useContext(ChatContext); const handleDeleteChannel = () => { chatService.socket?.emit("deleteChat", { chatId: chat.id }); }; @@ -30,11 +30,11 @@ export default function ChannelCard({ chat }: ChannelCardProps) {
{chat.chatType === 'PROTECTED' && } - {chat.owner === 'caio' && } + {chat.owner === user.login && }
- {chat.owner === 'caio' && ( + {chat.owner === user.login && ( user.userLogin !== currentUser.login); return ( @@ -47,7 +52,7 @@ export default function ChatUsersChannelPopOver({ {...attributes.popper} >
- {users.map((user) => ( + {otherUsers.map((user) => (
([]); @@ -149,7 +149,7 @@ export function OpenChannel() {
void; setValidationRequired: React.Dispatch>; validationRequired: boolean; + user: User; }; type ChatProviderProps = { @@ -37,6 +38,10 @@ export type Chat = { owner: string; }; +export type User = { + login: string; +}; + export type ChatList = Chat[]; export const ChatContext = createContext( @@ -48,6 +53,7 @@ export const ChatProvider = ({ children }: ChatProviderProps) => { const [showElement, setShowElement] = useState("showChannels"); const [selectedChat, setSelectedChat] = useState({} as Chat); const [chatList, setChatList] = useState([]); + const [user, setUser] = useState({} as User); const [isLoading, setIsLoading] = useState(true); const [validationRequired, setValidationRequired] = useState(true); @@ -68,6 +74,10 @@ export const ChatProvider = ({ children }: ChatProviderProps) => { useEffect(() => { // Connect to the Socket.IO server chatService.connect(); + chatService.socket?.on("userLogin", (user: User) => { + console.log(`Current user login: ${user.login}`); + setUser(() => user); + }); // Listen for incoming messages recursively every 10 seconds chatService.socket?.on("listChats", (newChatList: ChatList) => { setChatList(() => newChatList); @@ -132,6 +142,7 @@ export const ChatProvider = ({ children }: ChatProviderProps) => { handleCloseChat, setValidationRequired, validationRequired, + user, }} > {children} diff --git a/frontend/src/services/chatClient.ts b/frontend/src/services/chatClient.ts index 41622d12..7eadb4d0 100644 --- a/frontend/src/services/chatClient.ts +++ b/frontend/src/services/chatClient.ts @@ -1,7 +1,6 @@ -import { io, Socket } from 'socket.io-client'; +import { io, Socket } from "socket.io-client"; class ChatService { - public socket: Socket | null; constructor() { @@ -14,15 +13,15 @@ class ChatService { transports: ["websocket", "polling", "flashsocket"], }); } - }; + } - public disconnect (): void { + public disconnect(): void { if (this.socket) { this.socket.disconnect(); } this.socket = null; console.log("Disconnected from the WebSocket server"); - }; + } } const chatService = new ChatService();