From 299c6d515d1906bfca449fc616a98f83157504bf Mon Sep 17 00:00:00 2001 From: Caio Zullo Date: Wed, 13 Mar 2024 11:28:17 +0200 Subject: [PATCH] Wait for URLProtocol to start loading requests before cancelling the URLSessionTasks when running tests to achieve predictable test results and prevent tasks leaking outside the test scope (which could affect other tests). --- .../Helpers/URLProtocolStub.swift | 16 +++++++++++++--- .../URLSessionHTTPClientTests.swift | 16 +++++++--------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/EssentialFeed/EssentialFeedTests/Shared API Infra/Helpers/URLProtocolStub.swift b/EssentialFeed/EssentialFeedTests/Shared API Infra/Helpers/URLProtocolStub.swift index f96a591f..ea80c6c5 100644 --- a/EssentialFeed/EssentialFeedTests/Shared API Infra/Helpers/URLProtocolStub.swift +++ b/EssentialFeed/EssentialFeedTests/Shared API Infra/Helpers/URLProtocolStub.swift @@ -9,6 +9,7 @@ class URLProtocolStub: URLProtocol { let data: Data? let response: URLResponse? let error: Error? + let shouldCancelTask: Bool let requestObserver: ((URLRequest) -> Void)? } @@ -21,11 +22,15 @@ class URLProtocolStub: URLProtocol { private static let queue = DispatchQueue(label: "URLProtocolStub.queue") static func stub(data: Data?, response: URLResponse?, error: Error?) { - stub = Stub(data: data, response: response, error: error, requestObserver: nil) + stub = Stub(data: data, response: response, error: error, shouldCancelTask: false, requestObserver: nil) } - static func observeRequests(observer: @escaping (URLRequest) -> Void) { - stub = Stub(data: nil, response: nil, error: nil, requestObserver: observer) + static func cancelIncomingTasks() { + stub = Stub(data: nil, response: nil, error: nil, shouldCancelTask: true, requestObserver: nil) + } + + static func observeRequests(shouldFinish: Bool = true, observer: @escaping (URLRequest) -> Void) { + stub = Stub(data: nil, response: nil, error: nil, shouldCancelTask: false, requestObserver: observer) } static func removeStub() { @@ -43,6 +48,11 @@ class URLProtocolStub: URLProtocol { override func startLoading() { guard let stub = URLProtocolStub.stub else { return } + if stub.shouldCancelTask { + task?.cancel() + return + } + if let data = stub.data { client?.urlProtocol(self, didLoad: data) } diff --git a/EssentialFeed/EssentialFeedTests/Shared API Infra/URLSessionHTTPClientTests.swift b/EssentialFeed/EssentialFeedTests/Shared API Infra/URLSessionHTTPClientTests.swift index 84bc38eb..13e08475 100644 --- a/EssentialFeed/EssentialFeedTests/Shared API Infra/URLSessionHTTPClientTests.swift +++ b/EssentialFeed/EssentialFeedTests/Shared API Infra/URLSessionHTTPClientTests.swift @@ -29,11 +29,9 @@ class URLSessionHTTPClientTests: XCTestCase { } func test_cancelGetFromURLTask_cancelsURLRequest() { - let exp = expectation(description: "Wait for request") - URLProtocolStub.observeRequests { _ in exp.fulfill() } + URLProtocolStub.cancelIncomingTasks() - let receivedError = resultErrorFor(taskHandler: { $0.cancel() }) as NSError? - wait(for: [exp], timeout: 1.0) + let receivedError = resultErrorFor() as NSError? XCTAssertEqual(receivedError?.code, URLError.cancelled.rawValue) } @@ -104,8 +102,8 @@ class URLSessionHTTPClientTests: XCTestCase { } } - private func resultErrorFor(_ values: (data: Data?, response: URLResponse?, error: Error?)? = nil, taskHandler: (HTTPClientTask) -> Void = { _ in }, file: StaticString = #filePath, line: UInt = #line) -> Error? { - let result = resultFor(values, taskHandler: taskHandler, file: file, line: line) + private func resultErrorFor(_ values: (data: Data?, response: URLResponse?, error: Error?)? = nil, file: StaticString = #filePath, line: UInt = #line) -> Error? { + let result = resultFor(values, file: file, line: line) switch result { case let .failure(error): @@ -116,17 +114,17 @@ class URLSessionHTTPClientTests: XCTestCase { } } - private func resultFor(_ values: (data: Data?, response: URLResponse?, error: Error?)?, taskHandler: (HTTPClientTask) -> Void = { _ in }, file: StaticString = #filePath, line: UInt = #line) -> HTTPClient.Result { + private func resultFor(_ values: (data: Data?, response: URLResponse?, error: Error?)?, file: StaticString = #filePath, line: UInt = #line) -> HTTPClient.Result { values.map { URLProtocolStub.stub(data: $0, response: $1, error: $2) } let sut = makeSUT(file: file, line: line) let exp = expectation(description: "Wait for completion") var receivedResult: HTTPClient.Result! - taskHandler(sut.get(from: anyURL()) { result in + sut.get(from: anyURL()) { result in receivedResult = result exp.fulfill() - }) + } wait(for: [exp], timeout: 1.0) return receivedResult