diff --git a/IntegrationTests/tests_01_http/test_18_close_with_no_keepalive.sh b/IntegrationTests/tests_01_http/test_18_close_with_no_keepalive.sh index e687fba393..7ffd108bca 100644 --- a/IntegrationTests/tests_01_http/test_18_close_with_no_keepalive.sh +++ b/IntegrationTests/tests_01_http/test_18_close_with_no_keepalive.sh @@ -28,6 +28,7 @@ echo -e 'GET /dynamic/count-to-ten HTTP/1.1\r\nConnection: close\r\n\r\n' | \ backslash_r=$(echo -ne '\r') cat > "$tmp/expected" < "$tmp/out_actual" 2>&1 +grep -qi '< Connection: keep-alive' "$tmp/out_actual" +grep -qi '< HTTP/1.0 200 OK' "$tmp/out_actual" +stop_server "$token" diff --git a/Sources/NIOHTTP1Server/main.swift b/Sources/NIOHTTP1Server/main.swift index dd64ef5e08..991017453d 100644 --- a/Sources/NIOHTTP1Server/main.swift +++ b/Sources/NIOHTTP1Server/main.swift @@ -33,6 +33,28 @@ extension String { } } +private func httpResponseHead(request: HTTPRequestHead, status: HTTPResponseStatus, headers: HTTPHeaders = HTTPHeaders()) -> HTTPResponseHead { + var head = HTTPResponseHead(version: request.version, status: status, headers: headers) + let connectionHeaders: [String] = head.headers[canonicalForm: "connection"].map { $0.lowercased() } + + if !connectionHeaders.contains("keep-alive") && !connectionHeaders.contains("close") { + // the user hasn't pre-set either 'keep-alive' or 'close', so we might need to add headers + + switch (request.isKeepAlive, request.version.major, request.version.minor) { + case (true, 1, 0): + // HTTP/1.0 and the request has 'Connection: keep-alive', we should mirror that + head.headers.add(name: "Connection", value: "keep-alive") + case (false, 1, let n) where n >= 1: + // HTTP/1.1 (or treated as such) and the request has 'Connection: close', we should mirror that + head.headers.add(name: "Connection", value: "close") + default: + // we should match the default or are dealing with some HTTP that we don't support, let's leave as is + () + } + } + return head +} + private final class HTTPHandler: ChannelInboundHandler { private enum FileIOMethod { case sendfile @@ -104,7 +126,7 @@ private final class HTTPHandler: ChannelInboundHandler { self.buffer.write(string: response) var headers = HTTPHeaders() headers.add(name: "Content-Length", value: "\(response.utf8.count)") - ctx.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: self.infoSavedRequestHead!.version, status: .ok, headers: headers))), promise: nil) + ctx.write(self.wrapOutboundOut(.head(httpResponseHead(request: self.infoSavedRequestHead!, status: .ok, headers: headers))), promise: nil) ctx.write(self.wrapOutboundOut(.body(.byteBuffer(self.buffer))), promise: nil) self.completeResponse(ctx, trailers: nil, promise: nil) } @@ -118,6 +140,7 @@ private final class HTTPHandler: ChannelInboundHandler { switch request { case .head(let request): self.keepAlive = request.isKeepAlive + self.infoSavedRequestHead = request self.state.requestReceived() if balloonInMemory { self.buffer.clear() @@ -135,7 +158,7 @@ private final class HTTPHandler: ChannelInboundHandler { if balloonInMemory { var headers = HTTPHeaders() headers.add(name: "Content-Length", value: "\(self.buffer.readableBytes)") - ctx.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 0), status: .ok, headers: headers))), promise: nil) + ctx.write(self.wrapOutboundOut(.head(httpResponseHead(request: self.infoSavedRequestHead!, status: .ok, headers: headers))), promise: nil) ctx.write(self.wrapOutboundOut(.body(.byteBuffer(self.buffer))), promise: nil) self.completeResponse(ctx, trailers: nil, promise: nil) } else { @@ -185,7 +208,7 @@ private final class HTTPHandler: ChannelInboundHandler { self.completeResponse(ctx, trailers: nil, promise: nil) } } - ctx.writeAndFlush(self.wrapOutboundOut(.head(HTTPResponseHead(version: request.version, status: .ok))), promise: nil) + ctx.writeAndFlush(self.wrapOutboundOut(.head(httpResponseHead(request: request, status: .ok))), promise: nil) doNext() case .end: self.state.requestComplete() @@ -212,7 +235,7 @@ private final class HTTPHandler: ChannelInboundHandler { } } } - ctx.writeAndFlush(self.wrapOutboundOut(.head(HTTPResponseHead(version: request.version, status: .ok))), promise: nil) + ctx.writeAndFlush(self.wrapOutboundOut(.head(httpResponseHead(request: request, status: .ok))), promise: nil) doNext() case .end: self.state.requestComplete() @@ -254,7 +277,7 @@ private final class HTTPHandler: ChannelInboundHandler { self.keepAlive = request.isKeepAlive self.state.requestReceived() guard !request.uri.containsDotDot() else { - let response = HTTPResponseHead(version: request.version, status: .forbidden) + let response = httpResponseHead(request: request, status: .forbidden) ctx.write(self.wrapOutboundOut(.head(response)), promise: nil) self.completeResponse(ctx, trailers: nil, promise: nil) return @@ -263,7 +286,7 @@ private final class HTTPHandler: ChannelInboundHandler { do { let file = try FileHandle(path: path) let region = try FileRegion(fileHandle: file) - var response = HTTPResponseHead(version: request.version, status: .ok) + var response = httpResponseHead(request: request, status: .ok) response.headers.add(name: "Content-Length", value: "\(region.endIndex)") response.headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") @@ -287,7 +310,7 @@ private final class HTTPHandler: ChannelInboundHandler { return p.futureResult }.thenIfError { error in if !responseStarted { - let response = HTTPResponseHead(version: request.version, status: .ok) + let response = httpResponseHead(request: request, status: .ok) ctx.write(self.wrapOutboundOut(.head(response)), promise: nil) var buffer = ctx.channel.allocator.buffer(capacity: 100) buffer.write(string: "fail: \(error)") @@ -319,15 +342,15 @@ private final class HTTPHandler: ChannelInboundHandler { switch error { case let e as IOError where e.errnoCode == ENOENT: body.write(staticString: "IOError (not found)\r\n") - return HTTPResponseHead(version: request.version, status: .notFound) + return httpResponseHead(request: request, status: .notFound) case let e as IOError: body.write(staticString: "IOError (other)\r\n") body.write(string: e.description) body.write(staticString: "\r\n") - return HTTPResponseHead(version: request.version, status: .notFound) + return httpResponseHead(request: request, status: .notFound) default: body.write(string: "\(type(of: error)) error\r\n") - return HTTPResponseHead(version: request.version, status: .internalServerError) + return httpResponseHead(request: request, status: .internalServerError) } }() body.write(string: "\(error)") @@ -381,7 +404,7 @@ private final class HTTPHandler: ChannelInboundHandler { self.keepAlive = request.isKeepAlive self.state.requestReceived() - var responseHead = HTTPResponseHead(version: request.version, status: HTTPResponseStatus.ok) + var responseHead = httpResponseHead(request: request, status: HTTPResponseStatus.ok) responseHead.headers.add(name: "content-length", value: "12") let response = HTTPServerResponsePart.head(responseHead) ctx.write(self.wrapOutboundOut(response), promise: nil)