From cbea98d59edb4f0664448220d482c958096260fa Mon Sep 17 00:00:00 2001 From: Mike Nachbaur Date: Wed, 5 Jun 2024 14:58:31 -0700 Subject: [PATCH] Test that token refreshes merge the results correctly --- .../AuthFoundation/JWT/Protocols/Claim.swift | 2 +- .../CredentialRefreshTests.swift | 41 +++++++++++++++++-- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/Sources/AuthFoundation/JWT/Protocols/Claim.swift b/Sources/AuthFoundation/JWT/Protocols/Claim.swift index 7f409bcc5..a49784c6e 100644 --- a/Sources/AuthFoundation/JWT/Protocols/Claim.swift +++ b/Sources/AuthFoundation/JWT/Protocols/Claim.swift @@ -35,7 +35,7 @@ public protocol HasClaims { /// Unlike the ``claims`` property, this returns values as strings. var customClaims: [String] { get } - /// Raw paylaod of claims, as a dictionary representation. + /// Raw payload of claims, as a dictionary representation. var payload: [String: Any] { get } } diff --git a/Tests/AuthFoundationTests/CredentialRefreshTests.swift b/Tests/AuthFoundationTests/CredentialRefreshTests.swift index 6aa2c3f7e..aff3a56ae 100644 --- a/Tests/AuthFoundationTests/CredentialRefreshTests.swift +++ b/Tests/AuthFoundationTests/CredentialRefreshTests.swift @@ -44,7 +44,7 @@ final class CredentialRefreshTests: XCTestCase, OAuth2ClientDelegate { case none case error case openIdOnly - case refresh(count: Int) + case refresh(count: Int, rotate: Bool = false) } func credential(for token: Token, expectAPICalls: APICalls = .refresh(count: 1), expiresIn: TimeInterval = 3600) throws -> Credential { @@ -61,11 +61,11 @@ final class CredentialRefreshTests: XCTestCase, OAuth2ClientDelegate { data: try data(from: .module, for: "openid-configuration", in: "MockResponses"), contentType: "application/json") - case .refresh(let count): + case .refresh(let count, let rotate): urlSession.expect("https://example.com/.well-known/openid-configuration", data: try data(from: .module, for: "openid-configuration", in: "MockResponses"), contentType: "application/json") - for _ in 1 ... count { + for index in 1 ... count { urlSession.expect("https://example.com/oauth2/v1/token", data: data(for: """ { @@ -73,7 +73,7 @@ final class CredentialRefreshTests: XCTestCase, OAuth2ClientDelegate { "expires_in": \(expiresIn), "access_token": "\(String.mockAccessToken)", "scope": "openid profile offline_access", - "refresh_token": "therefreshtoken", + "refresh_token": "therefreshtoken\(rotate ? "-\(index)" : "")", "id_token": "\(String.mockIdToken)" } """)) @@ -310,6 +310,39 @@ final class CredentialRefreshTests: XCTestCase, OAuth2ClientDelegate { XCTAssertEqual(request.value(forHTTPHeaderField: "Authorization"), "Bearer \(credential.token.accessToken)") } + + func testRotatingRefreshTokens() throws { + let credential = try credential(for: Token.mockToken(expiresIn: 1), + expectAPICalls: .refresh(count: 3, rotate: true), + expiresIn: 1) + + // Initial refresh token + XCTAssertEqual(credential.token.refreshToken, "abc123") + + // First refresh + var refreshExpectation = expectation(description: "First refresh") + credential.refresh { _ in + refreshExpectation.fulfill() + } + wait(for: [refreshExpectation], timeout: .standard) + XCTAssertEqual(credential.token.refreshToken, "therefreshtoken-1") + + // Second refresh + refreshExpectation = expectation(description: "Second refresh") + credential.refresh { _ in + refreshExpectation.fulfill() + } + wait(for: [refreshExpectation], timeout: .standard) + XCTAssertEqual(credential.token.refreshToken, "therefreshtoken-2") + + // Third refresh + refreshExpectation = expectation(description: "Third refresh") + credential.refresh { _ in + refreshExpectation.fulfill() + } + wait(for: [refreshExpectation], timeout: .standard) + XCTAssertEqual(credential.token.refreshToken, "therefreshtoken-3") + } #if swift(>=5.5.1) @available(iOS 13.0, tvOS 13.0, macOS 10.15, watchOS 6, *)