Skip to content

Commit 1e03f82

Browse files
committed
Adding Timeout
1 parent 48f7b17 commit 1e03f82

File tree

7 files changed

+174
-16
lines changed

7 files changed

+174
-16
lines changed

Sources/HTTPConnection.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ final class HTTPConnection: Hashable {
4646
HTTPRequestSequence(bytes: socket.bytes)
4747
}
4848

49-
func sendResponse(_ response: HTTPResponse, for request: HTTPRequest) throws {
50-
try socket.write(HTTPResponseEncoder.encodeResponse(response, for: request))
49+
func sendResponse(_ response: HTTPResponse) throws {
50+
try socket.write(HTTPResponseEncoder.encodeResponse(response))
5151
}
5252

5353
func close() throws {

Sources/HTTPResponse.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,9 @@ public struct HTTPResponse {
4747
self.body = body
4848
}
4949
}
50+
51+
extension HTTPResponse {
52+
var shouldKeepAlive: Bool {
53+
headers[.connection]?.caseInsensitiveCompare("keep-alive") == .orderedSame
54+
}
55+
}

Sources/HTTPResponseEncoder.swift

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,15 @@ import Foundation
3333

3434
struct HTTPResponseEncoder {
3535

36-
static func makeHeaderLines(from response: HTTPResponse, for request: HTTPRequest) -> [String] {
36+
static func makeHeaderLines(from response: HTTPResponse) -> [String] {
3737
let status = [response.version.rawValue,
3838
String(response.statusCode.code),
3939
response.statusCode.phrase].joined(separator: " ")
4040

4141
var httpHeaders = response.headers
4242

43-
if request.shouldKeepAlive {
44-
httpHeaders[.connection] = "keep-alive"
45-
}
46-
4743
if response.body.isEmpty {
48-
httpHeaders[.contentLength] = nil
44+
httpHeaders[.contentLength] = String(0)
4945
} else {
5046
httpHeaders[.contentLength] = String(response.body.count)
5147
}
@@ -55,13 +51,15 @@ struct HTTPResponseEncoder {
5551
return [status] + headers + ["\r\n"]
5652
}
5753

58-
static func encodeResponse(_ response: HTTPResponse, for request: HTTPRequest) throws -> Data {
59-
guard var data = makeHeaderLines(from: response, for: request)
54+
static func encodeResponse(_ response: HTTPResponse) throws -> Data {
55+
guard var data = makeHeaderLines(from: response)
6056
.joined(separator: "\r\n")
6157
.data(using: .utf8) else {
6258
throw Error("Invalid Response Headers")
6359
}
60+
6461
data.append(response.body)
62+
6563
return data
6664
}
6765
}

Sources/HTTPServer.swift

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,14 @@ import Foundation
3434
public final actor HTTPServer {
3535

3636
private let port: UInt16
37+
private let timeout: TimeInterval
3738
private var socket: AsyncSocket?
3839
private var connections = [HTTPConnection: Task<Void, Never>]()
3940
private var handlers: [(route: HTTPRoute, handler: HTTPHandler)]
4041

41-
public init(port: UInt16, handlers: [(route: HTTPRoute, handler: HTTPHandler)] = []) {
42+
public init(port: UInt16, timeout: TimeInterval = 15, handlers: [(route: HTTPRoute, handler: HTTPHandler)] = []) {
4243
self.port = port
44+
self.timeout = timeout
4345
self.handlers = []
4446
}
4547

@@ -95,11 +97,12 @@ public final actor HTTPServer {
9597
connections[connection] = Task {
9698
do {
9799
for try await request in connection.requests {
98-
let response = await self.handleRequest(request)
99-
try connection.sendResponse(response, for: request)
100+
let response = await handleRequest(request)
101+
try connection.sendResponse(response)
102+
guard response.shouldKeepAlive else { break }
100103
}
101104
} catch {
102-
print("error", connection.hostname, error)
105+
print("connection error", connection.hostname, error)
103106
}
104107
removeConnection(connection)
105108
}
@@ -114,13 +117,24 @@ public final actor HTTPServer {
114117
}
115118

116119
func handleRequest(_ request: HTTPRequest) async -> HTTPResponse {
120+
var response = await handleRequest(request, timeout: timeout)
121+
if request.shouldKeepAlive {
122+
response.headers[.connection] = request.headers[.connection]
123+
}
124+
return response
125+
}
126+
127+
func handleRequest(_ request: HTTPRequest, timeout: TimeInterval) async -> HTTPResponse {
117128
guard let handler = handlers.first(where: { $0.route ~= request })?.handler else {
118129
return HTTPResponse(statusCode: .notFound)
119130
}
120131

121132
do {
122-
return try await handler.handleRequest(request)
133+
return try await withThrowingTimeout(seconds: timeout) {
134+
try await handler.handleRequest(request)
135+
}
123136
} catch {
137+
print("handler error", error)
124138
return HTTPResponse(statusCode: .serverError)
125139
}
126140
}

Sources/Socket/Socket.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ struct Socket: Hashable {
8989

9090
if let address = listenAddress {
9191
guard address.withCString({ cstring in inet_pton(AF_INET6, cstring, &addr.sin6_addr) }) == 1 else {
92-
print("\(address) is not converted.")
9392
throw SocketError.bindFailed(makeErrorMessage())
9493
}
9594
}

Sources/Task+Timeout.swift

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
//
2+
// Task+Timeout.swift
3+
// FlyingFox
4+
//
5+
// Created by Simon Whitty on 15/02/2022.
6+
// Copyright © 2022 Simon Whitty. All rights reserved.
7+
//
8+
// Distributed under the permissive MIT license
9+
// Get the latest version from here:
10+
//
11+
// https://github.com/swhitty/FlyingFox
12+
//
13+
// Permission is hereby granted, free of charge, to any person obtaining a copy
14+
// of this software and associated documentation files (the "Software"), to deal
15+
// in the Software without restriction, including without limitation the rights
16+
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
17+
// copies of the Software, and to permit persons to whom the Software is
18+
// furnished to do so, subject to the following conditions:
19+
//
20+
// The above copyright notice and this permission notice shall be included in all
21+
// copies or substantial portions of the Software.
22+
//
23+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
24+
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25+
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
26+
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27+
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
28+
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29+
// SOFTWARE.
30+
//
31+
32+
import Foundation
33+
34+
func withThrowingTimeout<T>(seconds: TimeInterval, body: @escaping @Sendable () async throws -> T) async throws -> T {
35+
try await withThrowingTaskGroup(of: T.self) { group -> T in
36+
group.addTask(operation: {
37+
try await body()
38+
})
39+
group.addTask {
40+
try await Task.sleep(nanoseconds: UInt64(seconds * 1_000_000_000))
41+
throw TimeoutError()
42+
}
43+
let success = try await group.next()!
44+
group.cancelAll()
45+
return success
46+
}
47+
}
48+
49+
struct TimeoutError: LocalizedError {
50+
var errorDescription: String? = "Timed out before completion"
51+
}
52+
53+
extension Task where Failure == Error {
54+
55+
// Start a new Task with a timeout.
56+
init(priority: TaskPriority? = nil, timeout: TimeInterval, operation: @escaping @Sendable () async throws -> Success) {
57+
self = Task(priority: priority) {
58+
try await withThrowingTimeout(seconds: timeout, body: operation)
59+
}
60+
}
61+
}

Tests/Task+TimeoutTests.swift

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
//
2+
// Task+TimeoutTests.swift
3+
// FlyingFox
4+
//
5+
// Created by Simon Whitty on 15/02/2022.
6+
// Copyright © 2022 Simon Whitty. All rights reserved.
7+
//
8+
// Distributed under the permissive MIT license
9+
// Get the latest version from here:
10+
//
11+
// https://github.com/swhitty/FlyingFox
12+
//
13+
// Permission is hereby granted, free of charge, to any person obtaining a copy
14+
// of this software and associated documentation files (the "Software"), to deal
15+
// in the Software without restriction, including without limitation the rights
16+
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
17+
// copies of the Software, and to permit persons to whom the Software is
18+
// furnished to do so, subject to the following conditions:
19+
//
20+
// The above copyright notice and this permission notice shall be included in all
21+
// copies or substantial portions of the Software.
22+
//
23+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
24+
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25+
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
26+
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27+
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
28+
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29+
// SOFTWARE.
30+
//
31+
32+
@testable import FlyingFox
33+
import XCTest
34+
35+
final class TaskTimeoutTests: XCTestCase {
36+
37+
func testTimeoutReturnsSuccess_WhenTimeoutDoesNotExpire() async throws {
38+
// given
39+
let value = try await Task(timeout: 0.5) {
40+
"Fish"
41+
}.value
42+
43+
// then
44+
XCTAssertEqual(value, "Fish")
45+
}
46+
47+
func testTimeoutThrowsError_WhenTimeoutExpires() async {
48+
// given
49+
let task = Task<Void, Error>(timeout: 0.5) {
50+
try? await Task.sleep(nanoseconds: 10_000_000_000)
51+
}
52+
53+
// then
54+
do {
55+
_ = try await task.value
56+
XCTFail("Expected TimeoutError")
57+
} catch {
58+
XCTAssertTrue(error is TimeoutError)
59+
}
60+
}
61+
62+
63+
func testTimeoutCancels() async {
64+
// given
65+
let task = Task(timeout: 0.5) {
66+
try? await Task.sleep(nanoseconds: 10_000_000_000)
67+
}
68+
69+
// when
70+
task.cancel()
71+
72+
// then
73+
do {
74+
_ = try await task.value
75+
XCTFail("Expected CancellationError")
76+
} catch {
77+
XCTAssertTrue(error is CancellationError)
78+
}
79+
}
80+
}

0 commit comments

Comments
 (0)