From 371279f4ec9234290aaaa3da799d3c634252b383 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Thu, 18 Apr 2024 19:20:32 +0900 Subject: [PATCH] making mkConn of WarpTLS interruptible --- warp-tls/Network/Wai/Handler/WarpTLS.hs | 46 ++++++++++++++++--------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/warp-tls/Network/Wai/Handler/WarpTLS.hs b/warp-tls/Network/Wai/Handler/WarpTLS.hs index 70cd3d034..a22c43d01 100644 --- a/warp-tls/Network/Wai/Handler/WarpTLS.hs +++ b/warp-tls/Network/Wai/Handler/WarpTLS.hs @@ -95,6 +95,8 @@ import UnliftIO.Exception ( try, ) import qualified UnliftIO.Exception as E +import UnliftIO.Concurrent (newEmptyMVar, putMVar, takeMVar, forkIOWithUnmask) +import UnliftIO.Timeout (timeout) ---------------------------------------------------------------- @@ -318,8 +320,18 @@ mkConn -> Socket -> params -> IO (Connection, Transport) -mkConn tlsset set s params = (safeRecv s 4096 >>= switch) `onException` close s +mkConn tlsset set s params = do + var <- newEmptyMVar + _ <- forkIOWithUnmask $ \umask -> do + let tm = settingsTimeout set * 1000000 + mct <- umask (timeout tm recvFirstBS) + putMVar var mct + mbs <- takeMVar var + case mbs of + Nothing -> throwIO IncompleteHeaders + Just bs -> switch bs where + recvFirstBS = safeRecv s 4096 `onException` close s switch firstBS | S.null firstBS = close s >> throwIO ClientClosedConnectionPrematurely | S.head firstBS == 0x16 = httpOverTls tlsset set s firstBS params @@ -335,22 +347,24 @@ httpOverTls -> S.ByteString -> params -> IO (Connection, Transport) -httpOverTls TLSSettings{..} _set s bs0 params = do - pool <- newBufferPool 2048 16384 - rawRecvN <- makeRecvN bs0 $ receive s pool - let recvN = wrappedRecvN rawRecvN - ctx <- TLS.contextNew (backend recvN) params - TLS.contextHookSetLogging ctx tlsLogging - TLS.handshake ctx - h2 <- (== Just "h2") <$> TLS.getNegotiatedProtocol ctx - isH2 <- I.newIORef h2 - writeBuffer <- createWriteBuffer 16384 - writeBufferRef <- I.newIORef writeBuffer - -- Creating a cache for leftover input data. - tls <- getTLSinfo ctx - mysa <- getSocketName s - return (conn ctx writeBufferRef isH2 mysa, tls) +httpOverTls TLSSettings{..} _set s bs0 params = + makeConn `onException` close s where + makeConn = do + pool <- newBufferPool 2048 16384 + rawRecvN <- makeRecvN bs0 $ receive s pool + let recvN = wrappedRecvN rawRecvN + ctx <- TLS.contextNew (backend recvN) params + TLS.contextHookSetLogging ctx tlsLogging + TLS.handshake ctx + h2 <- (== Just "h2") <$> TLS.getNegotiatedProtocol ctx + isH2 <- I.newIORef h2 + writeBuffer <- createWriteBuffer 16384 + writeBufferRef <- I.newIORef writeBuffer + -- Creating a cache for leftover input data. + tls <- getTLSinfo ctx + mysa <- getSocketName s + return (conn ctx writeBufferRef isH2 mysa, tls) backend recvN = TLS.Backend { TLS.backendFlush = return ()