From e9873e59c3c7ea0ce1d244104d3a4d0af4682487 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bence=20Nagygy=C3=B6rgy?= Date: Mon, 18 Sep 2023 13:54:41 +0200 Subject: [PATCH] fix(web): websocket close --- web/crux-ui/e2e/utils/test.fixture.ts | 2 +- web/crux-ui/src/providers/websocket.tsx | 11 ++++- .../src/websockets/websocket-client.ts | 19 ++++++-- .../config.bundle/config.bundle.ws.gateway.ts | 4 +- web/crux/src/app/deploy/deploy.ws.gateway.ts | 4 +- web/crux/src/app/token/jwt-auth.guard.ts | 5 +-- .../src/app/version/version.ws.gateway.ts | 4 +- web/crux/src/websockets/common.ts | 3 +- web/crux/src/websockets/dyo.ws.adapter.ts | 37 +++++++++++---- web/crux/src/websockets/namespace.spec.ts | 12 ++--- web/crux/src/websockets/namespace.ts | 41 +++++++++++------ web/crux/src/websockets/route.ts | 45 +++++++++++-------- 12 files changed, 120 insertions(+), 67 deletions(-) diff --git a/web/crux-ui/e2e/utils/test.fixture.ts b/web/crux-ui/e2e/utils/test.fixture.ts index 6168eea51b..40e79db5d0 100644 --- a/web/crux-ui/e2e/utils/test.fixture.ts +++ b/web/crux-ui/e2e/utils/test.fixture.ts @@ -22,7 +22,7 @@ test.beforeEach(async ({ page }, testInfo) => { return } - console.log(`[${testInfo.title}] ${type.toUpperCase()} ${it.text()}`) + console.info(`[${testInfo.title}] ${type.toUpperCase()} ${it.text()}`) }) if (CPU_THROTTLE) { diff --git a/web/crux-ui/src/providers/websocket.tsx b/web/crux-ui/src/providers/websocket.tsx index 2ad808f20c..ff4de53468 100644 --- a/web/crux-ui/src/providers/websocket.tsx +++ b/web/crux-ui/src/providers/websocket.tsx @@ -1,7 +1,9 @@ import { defaultWsErrorHandler } from '@app/errors' +import { ROUTE_LOGIN } from '@app/routes' import { isServerSide } from '@app/utils' import WebSocketClient from '@app/websockets/websocket-client' import useTranslation from 'next-translate/useTranslation' +import { useRouter } from 'next/router' import React, { useState } from 'react' interface WebSocketContextInterface { @@ -14,6 +16,7 @@ export const WebSocketProvider = (props: React.PropsWithChildren<{}>) => { const { children } = props const { t } = useTranslation('common') + const router = useRouter() const [wsClient] = useState(() => { if (isServerSide()) { @@ -23,7 +26,13 @@ export const WebSocketProvider = (props: React.PropsWithChildren<{}>) => { const client = new WebSocketClient() const wsErrorHandler = defaultWsErrorHandler(t) - client.setErrorHandler(msg => wsErrorHandler(msg)) + client.setErrorHandler(msg => { + if (msg.status === WebSocketClient.ERROR_SESSION_EXPIRED) { + router.push(ROUTE_LOGIN) + return + } + wsErrorHandler(msg) + }) return client }) diff --git a/web/crux-ui/src/websockets/websocket-client.ts b/web/crux-ui/src/websockets/websocket-client.ts index 244ba54c97..65067506c9 100644 --- a/web/crux-ui/src/websockets/websocket-client.ts +++ b/web/crux-ui/src/websockets/websocket-client.ts @@ -6,6 +6,8 @@ import WebSocketClientEndpoint from './websocket-client-endpoint' import WebSocketClientRoute from './websocket-client-route' class WebSocketClient { + public static ERROR_SESSION_EXPIRED = 4000 + private logger = new Logger('WebSocketClient') // need to be explicit string because of production build uglification private socket?: WebSocket @@ -91,7 +93,9 @@ class WebSocketClient { const ws = this.socket this.socket = null - ws.close() + if (ws.readyState === WebSocket.OPEN) { + ws.close() + } this.destroyListeners?.call(null) } @@ -203,7 +207,7 @@ class WebSocketClient { this.routes.forEach(it => it.onSocketOpen()) } - const onClose = () => { + const onClose = (it: CloseEvent) => { if (!resolved) { resolved = true setTimeout(() => resolve(false), failTimeout) @@ -211,8 +215,15 @@ class WebSocketClient { this.logger.info('Disconnected') - this.routes.forEach(it => it.onSocketClose()) - this.reconnect() + this.routes.forEach(route => route.onSocketClose()) + if (it.code === WebSocketClient.ERROR_SESSION_EXPIRED) { + this.errorHandler({ + status: WebSocketClient.ERROR_SESSION_EXPIRED, + message: it.reason, + }) + } else { + this.reconnect() + } } const onError = ev => { diff --git a/web/crux/src/app/config.bundle/config.bundle.ws.gateway.ts b/web/crux/src/app/config.bundle/config.bundle.ws.gateway.ts index 634c4509d7..1f69bca17d 100644 --- a/web/crux/src/app/config.bundle/config.bundle.ws.gateway.ts +++ b/web/crux/src/app/config.bundle/config.bundle.ws.gateway.ts @@ -87,15 +87,13 @@ export default class ConfigBundleWebSocketGateway { @SocketClient() client: WsClient, @ConfigBundleId() configBundleId: string, @SocketSubscription() subscription: WsSubscription, - ): Promise { + ) { const data = await this.service.onEditorLeft(configBundleId, client.token) const message: WsMessage = { type: WS_TYPE_EDITOR_LEFT, data, } subscription.sendToAllExcept(client, message) - - return true } @SubscribeMessage(WS_TYPE_PATCH_CONFIG_BUNDLE) diff --git a/web/crux/src/app/deploy/deploy.ws.gateway.ts b/web/crux/src/app/deploy/deploy.ws.gateway.ts index 47fe595770..c18210a4b0 100644 --- a/web/crux/src/app/deploy/deploy.ws.gateway.ts +++ b/web/crux/src/app/deploy/deploy.ws.gateway.ts @@ -137,7 +137,7 @@ export default class DeployWebSocketGateway { @SocketClient() client: WsClient, @DeploymentId() deploymentId: string, @SocketSubscription() subscription: WsSubscription, - ): Promise { + ) { const data = await this.service.onEditorLeft(deploymentId, client.token) const message: WsMessage = { type: WS_TYPE_EDITOR_LEFT, @@ -150,8 +150,6 @@ export default class DeployWebSocketGateway { this.deploymentEventCompleters.get(key).next(undefined) this.deploymentEventCompleters.delete(key) } - - return true } @AuditLogLevel('disabled') diff --git a/web/crux/src/app/token/jwt-auth.guard.ts b/web/crux/src/app/token/jwt-auth.guard.ts index 8c432f9edf..aed4f58c85 100644 --- a/web/crux/src/app/token/jwt-auth.guard.ts +++ b/web/crux/src/app/token/jwt-auth.guard.ts @@ -109,10 +109,10 @@ export default class JwtAuthGuard extends AuthGuard('jwt') { private async canActivateWs(context: ExecutionContext): Promise { const client: WsClient = context.switchToWs().getClient() + const message = this.reflector.get('message', context.getHandler()) if (client.disconnecting) { // NOTE(@robot9706): When a client is disconnecting disallow any handlers // except WsUnsubscribe for cleanup - const message = this.reflector.get('message', context.getHandler()) return message === WS_TYPE_UNSUBSCRIBE } @@ -124,9 +124,6 @@ export default class JwtAuthGuard extends AuthGuard('jwt') { if (!sessionExpiresAt || sessionExpiresAt <= now) { this.logger.debug('WebSocket session expired.') - await client.unsubscribeAll() - - client.close() throw new CruxUnauthorizedException() } diff --git a/web/crux/src/app/version/version.ws.gateway.ts b/web/crux/src/app/version/version.ws.gateway.ts index 9cec6598bf..04d2f0619e 100644 --- a/web/crux/src/app/version/version.ws.gateway.ts +++ b/web/crux/src/app/version/version.ws.gateway.ts @@ -101,15 +101,13 @@ export default class VersionWebSocketGateway { @SocketClient() client: WsClient, @VersionId() versionId: string, @SocketSubscription() subscription: WsSubscription, - ): Promise { + ) { const data = await this.service.onEditorLeft(versionId, client.token) const message: WsMessage = { type: WS_TYPE_EDITOR_LEFT, data, } subscription.sendToAllExcept(client, message) - - return true } @AuditLogLevel('disabled') diff --git a/web/crux/src/websockets/common.ts b/web/crux/src/websockets/common.ts index f9a5241aa9..8a2a44f9d1 100644 --- a/web/crux/src/websockets/common.ts +++ b/web/crux/src/websockets/common.ts @@ -62,7 +62,8 @@ export type WsClient = WebSocket & { connectionRequest: AuthorizedHttpRequest sendWsMessage: WsSendClientMessage subscriptions: Map - unsubscribeAll: () => Promise + unsubscribeAll: VoidFunction + expireTimeout: NodeJS.Timeout // NOTE(@robot9706): Used by jwt-auth.guard WS strategy to be able to call // unsubscribe methods when the session is invalid (for cleanup) diff --git a/web/crux/src/websockets/dyo.ws.adapter.ts b/web/crux/src/websockets/dyo.ws.adapter.ts index bf7f577b71..54caa3426a 100644 --- a/web/crux/src/websockets/dyo.ws.adapter.ts +++ b/web/crux/src/websockets/dyo.ws.adapter.ts @@ -11,6 +11,7 @@ import { catchError, filter, first, + forkJoin, from, fromEvent, mergeAll, @@ -48,6 +49,8 @@ export enum WebSocketReadyState { CLOSED_STATE = 3, } +const ERROR_SESSION_EXPIRED = 4000 + export default class DyoWsAdapter extends AbstractWsAdapter { private readonly logger = new Logger(DyoWsAdapter.name) @@ -266,7 +269,7 @@ export default class DyoWsAdapter extends AbstractWsAdapter { if (message.type === WS_TYPE_SUBSCRIBE) { res = await route.onSubscribe(client, match, message) } else if (message.type === WS_TYPE_UNSUBSCRIBE) { - res = await route.onUnsubscribe(client, match, message) + res = route.onUnsubscribe(client, match, message) } else { const err = new Error(`Invalid subscription type ${message.type}`) this.logger.verbose(err) @@ -293,6 +296,8 @@ export default class DyoWsAdapter extends AbstractWsAdapter { client.on(CLOSE_EVENT, () => this.onClientDisconnect(client)) client.unsubscribeAll = () => this.onClientDisconnect(client) + this.startClientExpiryTimer(client) + client.setup = new WsClientSetup(client, client.token, () => this.bindClientMessageHandlers(client)) client.setup.start() @@ -300,23 +305,39 @@ export default class DyoWsAdapter extends AbstractWsAdapter { this.logger.log(`Connected ${client.token} clients: ${this.server?.clients?.size}`) } - private async onClientDisconnect(client: WsClient): Promise { + private startClientExpiryTimer(client: WsClient) { + const { sessionExpiresAt } = client.connectionRequest + + const now = new Date().getTime() + const expireTime = sessionExpiresAt - now + + client.expireTimeout = setTimeout(() => { + this.logger.warn(`Session expired for ${client.token}`) + client.unsubscribeAll() + client.close(ERROR_SESSION_EXPIRED, 'Expired') + }, expireTime) + } + + private onClientDisconnect(client: WsClient) { if (client.disconnecting) { return } + clearTimeout(client.expireTimeout) + client.disconnecting = true this.logger.log(`Disconnected ${client.token} clients: ${this.server?.clients?.size}`) WsMetrics.connections().dec() - await Promise.all(this.routes.map(it => it.onClientDisconnect(client))) - - if (client?.subscriptions?.size > 0) { - this.logger.warn(`Client ${client.token} failed to cleanup all subscriptions!`) - } + const routeDisconnects = this.routes.map(it => it.onClientDisconnect(client)) + forkJoin(routeDisconnects).subscribe(() => { + if (client?.subscriptions?.size > 0) { + this.logger.warn(`Client ${client.token} failed to cleanup all subscriptions!`) + } - client?.setup?.onClientDisconnect() + client?.setup?.onClientDisconnect() + }) } private findRouteByPath(path: string): [WsRoute, WsRouteMatch] { diff --git a/web/crux/src/websockets/namespace.spec.ts b/web/crux/src/websockets/namespace.spec.ts index 4d6a5a8fa1..7ceceb4644 100644 --- a/web/crux/src/websockets/namespace.spec.ts +++ b/web/crux/src/websockets/namespace.spec.ts @@ -240,13 +240,13 @@ describe('WsNamespace', () => { }) it('should return UnsubscribeResult with a null when the client is not subscribed', async () => { - const result = await namespace.onUnsubscribe(client, unsubscribeMessage) + const result = await firstValueFrom(namespace.onUnsubscribe(client, unsubscribeMessage)) expect(result.res).toBe(null) }) it('should return UnsubscribeResult with shouldRemove true when the client is not subscribed and there is no other client subscribed', async () => { - const result = await namespace.onUnsubscribe(client, unsubscribeMessage) + const result = await firstValueFrom(namespace.onUnsubscribe(client, unsubscribeMessage)) expect(result.shouldRemove).toBe(true) }) @@ -335,7 +335,7 @@ describe('WsNamespace', () => { }) it('should return UnsubscribeResult with shouldRemove false when the client is not subscribed', async () => { - const result = await namespace.onUnsubscribe(client, unsubscribeMessage) + const result = await firstValueFrom(namespace.onUnsubscribe(client, unsubscribeMessage)) expect(result.shouldRemove).toBe(false) }) @@ -343,19 +343,19 @@ describe('WsNamespace', () => { it('should return UnsubscribeResult with shouldRemove false', async () => { namespace.onSubscribe(client, callbacks, subscribeMessage) - const result = await namespace.onUnsubscribe(client, unsubscribeMessage) + const result = await firstValueFrom(namespace.onUnsubscribe(client, unsubscribeMessage)) expect(result.shouldRemove).toBe(false) }) it('should return UnsubscribeResult with an observable containing the correct unsubbed message', async () => { - const result = await namespace.onUnsubscribe(subscribedClient, unsubscribeMessage) + const result = await firstValueFrom(namespace.onUnsubscribe(subscribedClient, unsubscribeMessage)) expect(result.res).toEqual(successfulUnsubscribeMessage) }) it('should return UnsubscribeResult with shouldRemove true, when this was the last client', async () => { - const result = await namespace.onUnsubscribe(subscribedClient, unsubscribeMessage) + const result = await firstValueFrom(namespace.onUnsubscribe(subscribedClient, unsubscribeMessage)) expect(result.shouldRemove).toBe(true) }) diff --git a/web/crux/src/websockets/namespace.ts b/web/crux/src/websockets/namespace.ts index ef5213fb12..6a28f2b28b 100644 --- a/web/crux/src/websockets/namespace.ts +++ b/web/crux/src/websockets/namespace.ts @@ -1,5 +1,5 @@ import { Logger } from '@nestjs/common' -import { EMPTY, Observable, Subject, filter, first, firstValueFrom, map, mergeWith, of, takeUntil } from 'rxjs' +import { EMPTY, Observable, Subject, filter, first, map, mergeWith, of, takeUntil } from 'rxjs' import { SubscriptionMessage, WS_TYPE_SUBBED, @@ -87,16 +87,16 @@ export default class WsNamespace implements WsSubscription { return of(res).pipe(first()) } - async onUnsubscribe(client: WsClient, message: WsMessage | null): Promise { + onUnsubscribe(client: WsClient, message: WsMessage | null): Observable { const { token } = client const resources = this.clients.get(token) if (!resources) { this.logger.warn(`undefined resource for '${token}'`) - return { + return of({ res: null, shouldRemove: this.clients.size < 1, - } + }) } // When the connection is killed, we get an empty message, @@ -113,14 +113,27 @@ export default class WsNamespace implements WsSubscription { const { unsubscribe, transform, completer } = resources if (unsubscribe) { - const unsubscribeResult = await firstValueFrom(transform(unsubscribe(message))) - if (!unsubscribeResult) { - this.logger.warn(`${this.path} @WsUnsubscribe returned undefined`) - return { - res: null, - shouldRemove: this.clients.size < 1, - } - } + return transform(unsubscribe(message)).pipe( + first(), + map(() => { + client.subscriptions.delete(this.path) + + completer.next(undefined) + this.clients.delete(token) + + this.logger.verbose(`${token} unsubscribed`) + + return { + res: { + type: WS_TYPE_UNSUBBED, + data: { + path: this.path, + }, + }, + shouldRemove: this.clients.size < 1, + } + }), + ) } client.subscriptions.delete(this.path) @@ -130,7 +143,7 @@ export default class WsNamespace implements WsSubscription { this.logger.verbose(`${token} unsubscribed`) - return { + return of({ res: { type: WS_TYPE_UNSUBBED, data: { @@ -138,7 +151,7 @@ export default class WsNamespace implements WsSubscription { }, }, shouldRemove: this.clients.size < 1, - } + }) } onMessage(client: WsClient, message: WsMessage): Observable { diff --git a/web/crux/src/websockets/route.ts b/web/crux/src/websockets/route.ts index 6637dbe317..9765d95162 100644 --- a/web/crux/src/websockets/route.ts +++ b/web/crux/src/websockets/route.ts @@ -1,6 +1,6 @@ import { Logger } from '@nestjs/common' import { MessageMappingProperties } from '@nestjs/websockets' -import { EMPTY, Observable, firstValueFrom, of } from 'rxjs' +import { EMPTY, Observable, firstValueFrom, forkJoin, map, of, switchMap } from 'rxjs' import WsMetrics from 'src/shared/metrics/ws.metrics' import { SubscriptionMessage, @@ -119,16 +119,14 @@ export default class WsRoute { return ns.onSubscribe(client, callbacks, message) } - async onUnsubscribe( + onUnsubscribe( client: WsClient, match: WsRouteMatch, message: WsMessage, - ): Promise>> { + ): Observable> { const { path } = match - const res = await this.removeClientFromNamespace(client, path, message) - - return res ? of(res) : EMPTY + return this.removeClientFromNamespace(client, path, message).pipe(switchMap(it => of(it) ?? EMPTY)) } onClientBind(client: WsClient, handlers: MessageMappingProperties[], transform: WsTransform) { @@ -169,12 +167,18 @@ export default class WsRoute { }) } - async onClientDisconnect(client: WsClient): Promise { + onClientDisconnect(client: WsClient): Observable { const subscriptionPaths = Array.from(client.subscriptions.keys()) - await Promise.all(Array.from(subscriptionPaths).map(it => this.removeClientFromNamespace(client, it, null))) + const unsubscribes = Array.from(subscriptionPaths).map(it => this.removeClientFromNamespace(client, it, null)) - this.callbacks.delete(client.token) + return unsubscribes.length + ? forkJoin(unsubscribes).pipe( + map(() => { + this.callbacks.delete(client.token) + }), + ) + : of() } private upsertNamespace(match: WsRouteMatch): WsNamespace { @@ -193,25 +197,28 @@ export default class WsRoute { return ns } - private async removeClientFromNamespace( + private removeClientFromNamespace( client: WsClient, namespacePath: string, message: WsMessage | null, - ): Promise> { + ): Observable> { const ns = this.namespaces.get(namespacePath) if (!ns) { - return null + return EMPTY } - const { res, shouldRemove } = await ns.onUnsubscribe(client, message) - if (shouldRemove) { - this.namespaces.delete(namespacePath) - this.logger.verbose(`Namespace deleted ${namespacePath}`) + return ns.onUnsubscribe(client, message).pipe( + map(({ res, shouldRemove }) => { + if (shouldRemove) { + this.namespaces.delete(namespacePath) + this.logger.verbose(`Namespace deleted ${namespacePath}`) - WsMetrics.routeNamespaces(this.path).set(this.countNamespaces()) - } + WsMetrics.routeNamespaces(this.path).set(this.countNamespaces()) + } - return res + return res + }), + ) } private pathFromParts(parts: string[], from?: number, to?: number): string {