diff --git a/web/crux/src/app/token/jwt-auth.guard.ts b/web/crux/src/app/token/jwt-auth.guard.ts index eb6ed6950..8c432f9ed 100644 --- a/web/crux/src/app/token/jwt-auth.guard.ts +++ b/web/crux/src/app/token/jwt-auth.guard.ts @@ -39,7 +39,7 @@ export default class JwtAuthGuard extends AuthGuard('jwt') { return await this.canActivateHttp(context, context.switchToHttp().getRequest(), strategy) } if (type === 'ws') { - return this.canActivateWs(context) + return await this.canActivateWs(context) } this.logger.error(`Invalid context ${type}`) @@ -107,7 +107,7 @@ export default class JwtAuthGuard extends AuthGuard('jwt') { req.headers.authorization = `Bearer ${token}` } - private canActivateWs(context: ExecutionContext): boolean { + private async canActivateWs(context: ExecutionContext): Promise { const client: WsClient = context.switchToWs().getClient() if (client.disconnecting) { // NOTE(@robot9706): When a client is disconnecting disallow any handlers @@ -123,6 +123,9 @@ export default class JwtAuthGuard extends AuthGuard('jwt') { if (!sessionExpiresAt || sessionExpiresAt <= now) { this.logger.debug('WebSocket session expired.') + + await client.unsubscribeAll() + client.close() throw new CruxUnauthorizedException() } diff --git a/web/crux/src/websockets/common.ts b/web/crux/src/websockets/common.ts index 60a30d981..f9a5241aa 100644 --- a/web/crux/src/websockets/common.ts +++ b/web/crux/src/websockets/common.ts @@ -62,6 +62,7 @@ export type WsClient = WebSocket & { connectionRequest: AuthorizedHttpRequest sendWsMessage: WsSendClientMessage subscriptions: Map + unsubscribeAll: () => Promise // NOTE(@robot9706): Used by jwt-auth.guard WS strategy to be able to call // unsubscribe methods when the session is invalid (for cleanup) diff --git a/web/crux/src/websockets/dyo.ws.adapter.ts b/web/crux/src/websockets/dyo.ws.adapter.ts index 55bed5a01..bf7f577b7 100644 --- a/web/crux/src/websockets/dyo.ws.adapter.ts +++ b/web/crux/src/websockets/dyo.ws.adapter.ts @@ -291,6 +291,7 @@ export default class DyoWsAdapter extends AbstractWsAdapter { client.send(JSON.stringify(msg)) } client.on(CLOSE_EVENT, () => this.onClientDisconnect(client)) + client.unsubscribeAll = () => this.onClientDisconnect(client) client.setup = new WsClientSetup(client, client.token, () => this.bindClientMessageHandlers(client)) client.setup.start() @@ -299,7 +300,11 @@ export default class DyoWsAdapter extends AbstractWsAdapter { this.logger.log(`Connected ${client.token} clients: ${this.server?.clients?.size}`) } - private async onClientDisconnect(client: WsClient) { + private async onClientDisconnect(client: WsClient): Promise { + if (client.disconnecting) { + return + } + client.disconnecting = true this.logger.log(`Disconnected ${client.token} clients: ${this.server?.clients?.size}`)