Skip to content

Commit

Permalink
workaround ab (ApacheBench) in keep-alive mode (#260)
Browse files Browse the repository at this point in the history
Motivation:

`ab -k` behaves weirdly: it'll send an HTTP/1.0 request with keep-alive set
but stops doing anything at all if the server doesn't also set Connection: keep-alive
which our example HTTP1Server didn't do.

Modifications:

In the HTTP1Server example if we receive an HTTP/1.0 request with
Connection: keep-alive, we'll now set keep-alive too.

Result:

ab -k doesn't get stuck anymore.
  • Loading branch information
weissi authored and normanmaurer committed Apr 10, 2018
1 parent b6681c0 commit d94ab56
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ echo -e 'GET /dynamic/count-to-ten HTTP/1.1\r\nConnection: close\r\n\r\n' | \
backslash_r=$(echo -ne '\r')
cat > "$tmp/expected" <<EOF
HTTP/1.1 200 OK$backslash_r
Connection: close$backslash_r
transfer-encoding: chunked$backslash_r
$backslash_r
1$backslash_r
Expand Down
24 changes: 24 additions & 0 deletions IntegrationTests/tests_01_http/test_22_http_1.0_keep_alive.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash
##===----------------------------------------------------------------------===##
##
## This source file is part of the SwiftNIO open source project
##
## Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
## Licensed under Apache License v2.0
##
## See LICENSE.txt for license information
## See CONTRIBUTORS.txt for the list of SwiftNIO project authors
##
## SPDX-License-Identifier: Apache-2.0
##
##===----------------------------------------------------------------------===##

source defines.sh

token=$(create_token)
start_server "$token"
do_curl "$token" -H 'connection: keep-alive' -v --http1.0 \
"http://foobar.com/dynamic/info" > "$tmp/out_actual" 2>&1
grep -qi '< Connection: keep-alive' "$tmp/out_actual"
grep -qi '< HTTP/1.0 200 OK' "$tmp/out_actual"
stop_server "$token"
45 changes: 34 additions & 11 deletions Sources/NIOHTTP1Server/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,28 @@ extension String {
}
}

private func httpResponseHead(request: HTTPRequestHead, status: HTTPResponseStatus, headers: HTTPHeaders = HTTPHeaders()) -> HTTPResponseHead {
var head = HTTPResponseHead(version: request.version, status: status, headers: headers)
let connectionHeaders: [String] = head.headers[canonicalForm: "connection"].map { $0.lowercased() }

if !connectionHeaders.contains("keep-alive") && !connectionHeaders.contains("close") {
// the user hasn't pre-set either 'keep-alive' or 'close', so we might need to add headers

switch (request.isKeepAlive, request.version.major, request.version.minor) {
case (true, 1, 0):
// HTTP/1.0 and the request has 'Connection: keep-alive', we should mirror that
head.headers.add(name: "Connection", value: "keep-alive")
case (false, 1, let n) where n >= 1:
// HTTP/1.1 (or treated as such) and the request has 'Connection: close', we should mirror that
head.headers.add(name: "Connection", value: "close")
default:
// we should match the default or are dealing with some HTTP that we don't support, let's leave as is
()
}
}
return head
}

private final class HTTPHandler: ChannelInboundHandler {
private enum FileIOMethod {
case sendfile
Expand Down Expand Up @@ -104,7 +126,7 @@ private final class HTTPHandler: ChannelInboundHandler {
self.buffer.write(string: response)
var headers = HTTPHeaders()
headers.add(name: "Content-Length", value: "\(response.utf8.count)")
ctx.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: self.infoSavedRequestHead!.version, status: .ok, headers: headers))), promise: nil)
ctx.write(self.wrapOutboundOut(.head(httpResponseHead(request: self.infoSavedRequestHead!, status: .ok, headers: headers))), promise: nil)
ctx.write(self.wrapOutboundOut(.body(.byteBuffer(self.buffer))), promise: nil)
self.completeResponse(ctx, trailers: nil, promise: nil)
}
Expand All @@ -118,6 +140,7 @@ private final class HTTPHandler: ChannelInboundHandler {
switch request {
case .head(let request):
self.keepAlive = request.isKeepAlive
self.infoSavedRequestHead = request
self.state.requestReceived()
if balloonInMemory {
self.buffer.clear()
Expand All @@ -135,7 +158,7 @@ private final class HTTPHandler: ChannelInboundHandler {
if balloonInMemory {
var headers = HTTPHeaders()
headers.add(name: "Content-Length", value: "\(self.buffer.readableBytes)")
ctx.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 0), status: .ok, headers: headers))), promise: nil)
ctx.write(self.wrapOutboundOut(.head(httpResponseHead(request: self.infoSavedRequestHead!, status: .ok, headers: headers))), promise: nil)
ctx.write(self.wrapOutboundOut(.body(.byteBuffer(self.buffer))), promise: nil)
self.completeResponse(ctx, trailers: nil, promise: nil)
} else {
Expand Down Expand Up @@ -185,7 +208,7 @@ private final class HTTPHandler: ChannelInboundHandler {
self.completeResponse(ctx, trailers: nil, promise: nil)
}
}
ctx.writeAndFlush(self.wrapOutboundOut(.head(HTTPResponseHead(version: request.version, status: .ok))), promise: nil)
ctx.writeAndFlush(self.wrapOutboundOut(.head(httpResponseHead(request: request, status: .ok))), promise: nil)
doNext()
case .end:
self.state.requestComplete()
Expand All @@ -212,7 +235,7 @@ private final class HTTPHandler: ChannelInboundHandler {
}
}
}
ctx.writeAndFlush(self.wrapOutboundOut(.head(HTTPResponseHead(version: request.version, status: .ok))), promise: nil)
ctx.writeAndFlush(self.wrapOutboundOut(.head(httpResponseHead(request: request, status: .ok))), promise: nil)
doNext()
case .end:
self.state.requestComplete()
Expand Down Expand Up @@ -254,7 +277,7 @@ private final class HTTPHandler: ChannelInboundHandler {
self.keepAlive = request.isKeepAlive
self.state.requestReceived()
guard !request.uri.containsDotDot() else {
let response = HTTPResponseHead(version: request.version, status: .forbidden)
let response = httpResponseHead(request: request, status: .forbidden)
ctx.write(self.wrapOutboundOut(.head(response)), promise: nil)
self.completeResponse(ctx, trailers: nil, promise: nil)
return
Expand All @@ -263,7 +286,7 @@ private final class HTTPHandler: ChannelInboundHandler {
do {
let file = try FileHandle(path: path)
let region = try FileRegion(fileHandle: file)
var response = HTTPResponseHead(version: request.version, status: .ok)
var response = httpResponseHead(request: request, status: .ok)

response.headers.add(name: "Content-Length", value: "\(region.endIndex)")
response.headers.add(name: "Content-Type", value: "text/plain; charset=utf-8")
Expand All @@ -287,7 +310,7 @@ private final class HTTPHandler: ChannelInboundHandler {
return p.futureResult
}.thenIfError { error in
if !responseStarted {
let response = HTTPResponseHead(version: request.version, status: .ok)
let response = httpResponseHead(request: request, status: .ok)
ctx.write(self.wrapOutboundOut(.head(response)), promise: nil)
var buffer = ctx.channel.allocator.buffer(capacity: 100)
buffer.write(string: "fail: \(error)")
Expand Down Expand Up @@ -319,15 +342,15 @@ private final class HTTPHandler: ChannelInboundHandler {
switch error {
case let e as IOError where e.errnoCode == ENOENT:
body.write(staticString: "IOError (not found)\r\n")
return HTTPResponseHead(version: request.version, status: .notFound)
return httpResponseHead(request: request, status: .notFound)
case let e as IOError:
body.write(staticString: "IOError (other)\r\n")
body.write(string: e.description)
body.write(staticString: "\r\n")
return HTTPResponseHead(version: request.version, status: .notFound)
return httpResponseHead(request: request, status: .notFound)
default:
body.write(string: "\(type(of: error)) error\r\n")
return HTTPResponseHead(version: request.version, status: .internalServerError)
return httpResponseHead(request: request, status: .internalServerError)
}
}()
body.write(string: "\(error)")
Expand Down Expand Up @@ -381,7 +404,7 @@ private final class HTTPHandler: ChannelInboundHandler {
self.keepAlive = request.isKeepAlive
self.state.requestReceived()

var responseHead = HTTPResponseHead(version: request.version, status: HTTPResponseStatus.ok)
var responseHead = httpResponseHead(request: request, status: HTTPResponseStatus.ok)
responseHead.headers.add(name: "content-length", value: "12")
let response = HTTPServerResponsePart.head(responseHead)
ctx.write(self.wrapOutboundOut(response), promise: nil)
Expand Down

0 comments on commit d94ab56

Please sign in to comment.