Skip to content

Commit

Permalink
fix(crux): WS session expire cleanup (#823)
Browse files Browse the repository at this point in the history
  • Loading branch information
robot9706 authored Sep 21, 2023
1 parent df2eebc commit 2059ab0
Show file tree
Hide file tree
Showing 13 changed files with 205 additions and 69 deletions.
10 changes: 9 additions & 1 deletion web/crux-ui/src/components/layout.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ import clsx from 'clsx'
import useTranslation from 'next-translate/useTranslation'
import { useRouter } from 'next/dist/client/router'
import Head from 'next/head'
import React from 'react'
import React, { useContext, useEffect } from 'react'
import useSWR from 'swr'
import Footer from './main/footer'
import { Sidebar } from './main/sidebar'
import Topbar from './main/top-bar'
import { WebSocketContext } from '@app/providers/websocket'

const sidebarWidth = 'w-[17rem]'
const mainWidth = 'w-[calc(100vw-17rem)]' // ViewWidth - sidebar
Expand Down Expand Up @@ -41,13 +42,20 @@ export interface LayoutProps {
export const Layout = (props: LayoutProps) => {
const { title, children, topBarContent } = props

const webSocketContext = useContext(WebSocketContext)
const { data: meta, error } = useSWR<UserMeta>(
API_USERS_ME,
configuredFetcher({
method: 'POST',
}),
)

useEffect(() => {
if (meta) {
webSocketContext.client?.reset()
}
}, [meta, webSocketContext.client])

const router = useRouter()
if (error) {
router.replace(ROUTE_LOGIN)
Expand Down
12 changes: 11 additions & 1 deletion web/crux-ui/src/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ import { Translate } from 'next-translate'
import toast, { ToastOptions } from 'react-hot-toast'
import { fromApiError } from './error-responses'
import { DyoErrorDto, WsErrorMessage } from './models'
import WebSocketClient from './websockets/websocket-client'
import { ROUTE_LOGIN } from './routes'
import { NextRouter } from 'next/router'

export type DyoApiErrorHandler = (res: Response, setErrorValue?: FormikSetErrorValue) => Promise<void>

Expand Down Expand Up @@ -77,4 +80,11 @@ export const wsErrorHandler = (translator: Translator) => (message: WsErrorMessa
toaster(translation.toast)
}

export const defaultWsErrorHandler = (t: Translate) => wsErrorHandler(defaultTranslator(t))
export const defaultWsErrorHandler = (t: Translate, router: NextRouter) => (msg: WsErrorMessage) => {
const defaultErrorHandler = wsErrorHandler(defaultTranslator(t))
if (msg.status === WebSocketClient.ERROR_UNAUTHORIZE) {
router.push(ROUTE_LOGIN)
return
}
defaultErrorHandler(msg)
}
6 changes: 4 additions & 2 deletions web/crux-ui/src/hooks/use-websocket-translation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@ import { defaultWsErrorHandler } from '@app/errors'
import { WebSocketContext } from '@app/providers/websocket'
import { Translate } from 'next-translate'
import useTranslation from 'next-translate/useTranslation'
import { useRouter } from 'next/router'
import { useContext, useEffect } from 'react'

const useWebsocketTranslate = (t: Translate) => {
const wsContext = useContext(WebSocketContext)
const router = useRouter()
const { t: defaultTranslate } = useTranslation('common')

useEffect(() => {
if (wsContext.client) {
wsContext.client.setErrorHandler(msg => defaultWsErrorHandler(t)(msg))
wsContext.client.setErrorHandler(defaultWsErrorHandler(t, router))
}
return () => {
if (wsContext.client) {
wsContext.client.setErrorHandler(msg => defaultWsErrorHandler(defaultTranslate)(msg))
wsContext.client.setErrorHandler(defaultWsErrorHandler(defaultTranslate, router))
}
}
// eslint-disable-next-line react-hooks/exhaustive-deps
Expand Down
6 changes: 4 additions & 2 deletions web/crux-ui/src/providers/websocket.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { defaultWsErrorHandler } from '@app/errors'
import { isServerSide } from '@app/utils'
import WebSocketClient from '@app/websockets/websocket-client'
import useTranslation from 'next-translate/useTranslation'
import { useRouter } from 'next/router'
import React, { useState } from 'react'

interface WebSocketContextInterface {
Expand All @@ -14,6 +15,7 @@ export const WebSocketProvider = (props: React.PropsWithChildren<{}>) => {
const { children } = props

const { t } = useTranslation('common')
const router = useRouter()

const [wsClient] = useState(() => {
if (isServerSide()) {
Expand All @@ -22,8 +24,8 @@ export const WebSocketProvider = (props: React.PropsWithChildren<{}>) => {

const client = new WebSocketClient()

const wsErrorHandler = defaultWsErrorHandler(t)
client.setErrorHandler(msg => wsErrorHandler(msg))
const wsErrorHandler = defaultWsErrorHandler(t, router)
client.setErrorHandler(wsErrorHandler)

return client
})
Expand Down
40 changes: 37 additions & 3 deletions web/crux-ui/src/websockets/websocket-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import WebSocketClientEndpoint from './websocket-client-endpoint'
import WebSocketClientRoute from './websocket-client-route'

class WebSocketClient {
// NOTE(@robot9706): According to the WebSocket spec the 4000-4999 code range is available to applications
public static ERROR_UNAUTHORIZE = 4401

private logger = new Logger('WebSocketClient') // need to be explicit string because of production build uglification

private socket?: WebSocket
Expand All @@ -20,6 +23,8 @@ class WebSocketClient {

private errorHandler: WsErrorHandler = null

private kicked: boolean = false

get connected(): boolean {
return this.socket?.readyState === WebSocket.OPEN
}
Expand Down Expand Up @@ -74,6 +79,21 @@ class WebSocketClient {
this.errorHandler = handler
}

reset() {
if (this.socket && this.socket?.readyState !== WebSocket.CLOSED) {
return
}

if (!this.kicked) {
return
}

this.kicked = false
this.connectionAttemptCount = 0

this.reconnect()
}

private removeRoute(route: WebSocketClientRoute) {
const { path } = route
this.routes.delete(path)
Expand Down Expand Up @@ -105,6 +125,10 @@ class WebSocketClient {
return true
}

if (this.kicked) {
return false
}

// if there is already a connctionAttempt wait for the result
if (this.connectionAttempt && (await this.connectionAttempt)) {
// it was successful
Expand Down Expand Up @@ -203,16 +227,26 @@ class WebSocketClient {
this.routes.forEach(it => it.onSocketOpen())
}

const onClose = () => {
const onClose = (it: CloseEvent) => {
if (!resolved) {
resolved = true
setTimeout(() => resolve(false), failTimeout)
}

this.logger.info('Disconnected')

this.routes.forEach(it => it.onSocketClose())
this.reconnect()
this.errorHandler({
status: it.code,
message: it.reason,
})

this.routes.forEach(route => route.onSocketClose())

if (it.code === WebSocketClient.ERROR_UNAUTHORIZE) {
this.kicked = true
} else {
this.reconnect()
}
}

const onError = ev => {
Expand Down
1 change: 1 addition & 0 deletions web/crux/.eslintrc.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ module.exports = {
rules: {
'eol-last': ['error', 'always'],
'@typescript-eslint/semi': ['error', 'never'],
'@typescript-eslint/no-floating-promises': 'error',

// disabled
'@typescript-eslint/interface-name-prefix': 'off',
Expand Down
20 changes: 10 additions & 10 deletions web/crux/src/app/node/node.service.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ describe('NodeService', () => {
nodeService = module.get<NodeService>(NodeService)
})

it('startContainer should create an audit event', () => {
nodeService.startContainer('test-node-id', 'test-prefix', 'test-name')
it('startContainer should create an audit event', async () => {
await nodeService.startContainer('test-node-id', 'test-prefix', 'test-name')

expect(createAgentEventMock).toHaveBeenCalledWith('test-node-id', 'containerCommand', {
container: {
Expand All @@ -63,8 +63,8 @@ describe('NodeService', () => {
})
})

it('stopContainer should create an audit event', () => {
nodeService.stopContainer('test-node-id', 'test-prefix', 'test-name')
it('stopContainer should create an audit event', async () => {
await nodeService.stopContainer('test-node-id', 'test-prefix', 'test-name')

expect(createAgentEventMock).toHaveBeenCalledWith('test-node-id', 'containerCommand', {
container: {
Expand All @@ -75,8 +75,8 @@ describe('NodeService', () => {
})
})

it('restartContainer should create an audit event', () => {
nodeService.restartContainer('test-node-id', 'test-prefix', 'test-name')
it('restartContainer should create an audit event', async () => {
await nodeService.restartContainer('test-node-id', 'test-prefix', 'test-name')

expect(createAgentEventMock).toHaveBeenCalledWith('test-node-id', 'containerCommand', {
container: {
Expand All @@ -87,8 +87,8 @@ describe('NodeService', () => {
})
})

it('deleteContainer should create an audit event', () => {
nodeService.deleteContainer('test-node-id', 'test-prefix', 'test-name')
it('deleteContainer should create an audit event', async () => {
await nodeService.deleteContainer('test-node-id', 'test-prefix', 'test-name')

expect(createAgentEventMock).toHaveBeenCalledWith('test-node-id', 'containerCommand', {
container: {
Expand All @@ -99,8 +99,8 @@ describe('NodeService', () => {
})
})

it('deleteAllContainers should create an audit event', () => {
nodeService.deleteAllContainers('test-node-id', 'test-prefix')
it('deleteAllContainers should create an audit event', async () => {
await nodeService.deleteAllContainers('test-node-id', 'test-prefix')

expect(createAgentEventMock).toHaveBeenCalledWith('test-node-id', 'containerCommand', {
prefix: 'test-prefix',
Expand Down
11 changes: 9 additions & 2 deletions web/crux/src/app/token/jwt-auth.guard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { Request as ExpressRequest } from 'express'
import { RequestAuthenticationData } from 'src/domain/identity'
import { CruxUnauthorizedException } from 'src/exception/crux-exception'
import KratosService, { hasKratosSession } from 'src/services/kratos.service'
import { WsClient } from 'src/websockets/common'
import { WS_TYPE_UNSUBSCRIBE, WsClient } from 'src/websockets/common'

export type AuthStrategyType = 'user-token' | 'deploy-token' | 'disabled'
export const AUTH_STRATEGY = 'auth-strategy'
Expand Down Expand Up @@ -109,14 +109,21 @@ export default class JwtAuthGuard extends AuthGuard('jwt') {

private canActivateWs(context: ExecutionContext): boolean {
const client: WsClient = context.switchToWs().getClient()
const message = this.reflector.get('message', context.getHandler())
if (client.disconnecting) {
// NOTE(@robot9706): When a client is disconnecting disallow any handlers
// except WsUnsubscribe for cleanup
return message === WS_TYPE_UNSUBSCRIBE
}

const req = client.connectionRequest as AuthorizedHttpRequest

const now = new Date().getTime()
const { sessionExpiresAt } = req

if (!sessionExpiresAt || sessionExpiresAt <= now) {
this.logger.debug('WebSocket session expired.')
client.close()

throw new CruxUnauthorizedException()
}

Expand Down
6 changes: 6 additions & 0 deletions web/crux/src/websockets/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ export type WsClient = WebSocket & {
connectionRequest: AuthorizedHttpRequest
sendWsMessage: WsSendClientMessage
subscriptions: Map<string, WsSubscription>
unsubscribeAll: VoidFunction
expireTimeout: NodeJS.Timeout

// NOTE(@robot9706): Used by jwt-auth.guard WS strategy to be able to call
// unsubscribe methods when the session is invalid (for cleanup)
disconnecting: boolean
}

export type WsTransform = (data: any | Promise<any> | Observable<any>) => Observable<any>
Expand Down
47 changes: 43 additions & 4 deletions web/crux/src/websockets/dyo.ws.adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import {
Observable,
Subject,
catchError,
combineLatest,
filter,
finalize,
first,
from,
fromEvent,
Expand Down Expand Up @@ -48,6 +50,9 @@ export enum WebSocketReadyState {
CLOSED_STATE = 3,
}

// NOTE(@robot9706): According to the WebSocket spec the 4000-4999 code range is available to applications
const ERROR_UNAUTHORIZED = 4401

export default class DyoWsAdapter extends AbstractWsAdapter {
private readonly logger = new Logger(DyoWsAdapter.name)

Expand Down Expand Up @@ -280,8 +285,9 @@ export default class DyoWsAdapter extends AbstractWsAdapter {
client.token = uuid()
client.connectionRequest = req as AuthorizedHttpRequest
client.subscriptions = new Map()
client.disconnecting = false
client.sendWsMessage = msg => {
if (!msg || client.readyState !== WebSocketReadyState.OPEN_STATE) {
if (!msg || client.readyState !== WebSocketReadyState.OPEN_STATE || client.disconnecting) {
return
}

Expand All @@ -290,6 +296,9 @@ export default class DyoWsAdapter extends AbstractWsAdapter {
client.send(JSON.stringify(msg))
}
client.on(CLOSE_EVENT, () => this.onClientDisconnect(client))
client.unsubscribeAll = () => this.onClientDisconnect(client)

this.startClientExpiryTimer(client)

client.setup = new WsClientSetup(client, client.token, () => this.bindClientMessageHandlers(client))
client.setup.start()
Expand All @@ -298,13 +307,43 @@ export default class DyoWsAdapter extends AbstractWsAdapter {
this.logger.log(`Connected ${client.token} clients: ${this.server?.clients?.size}`)
}

private startClientExpiryTimer(client: WsClient) {
const { sessionExpiresAt } = client.connectionRequest

const now = new Date().getTime()
const expireTime = sessionExpiresAt - now

client.expireTimeout = setTimeout(() => {
this.logger.warn(`Session expired for ${client.token}`)
client.unsubscribeAll()
client.close(ERROR_UNAUTHORIZED, 'Expired')
}, expireTime)
}

private onClientDisconnect(client: WsClient) {
if (client.disconnecting) {
return
}

clearTimeout(client.expireTimeout)

client.disconnecting = true

this.logger.log(`Disconnected ${client.token} clients: ${this.server?.clients?.size}`)
WsMetrics.connections().dec()

this.routes.forEach(it => it.onClientDisconnect(client))

client?.setup?.onClientDisconnect()
const routeDisconnects = this.routes.map(it => it.onClientDisconnect(client))
combineLatest(routeDisconnects)
.pipe(
finalize(() => {
if (client?.subscriptions?.size > 0) {
this.logger.warn(`Client ${client.token} failed to cleanup all subscriptions!`)
}

client?.setup?.onClientDisconnect()
}),
)
.subscribe()
}

private findRouteByPath(path: string): [WsRoute, WsRouteMatch] {
Expand Down
Loading

0 comments on commit 2059ab0

Please sign in to comment.