diff --git a/Sources/AuthFoundation/OAuth2/OAuth2Client.swift b/Sources/AuthFoundation/OAuth2/OAuth2Client.swift index ee8974be5..81c189728 100644 --- a/Sources/AuthFoundation/OAuth2/OAuth2Client.swift +++ b/Sources/AuthFoundation/OAuth2/OAuth2Client.swift @@ -117,7 +117,7 @@ public final class OAuth2Client { /// If this value has recently been retrieved, the cached result is returned. /// - Parameter completion: Completion block invoked with the result. public func openIdConfiguration(completion: @escaping (Result) -> Void) { - configurationQueue.sync { + configurationLock.withLock { if let openIdConfiguration = openIdConfiguration { completion(.success(openIdConfiguration)) } else { @@ -161,14 +161,14 @@ public final class OAuth2Client { /// - token: Token to refresh. /// - completion: Completion bock invoked with the result. public func refresh(_ token: Token, completion: @escaping (Result) -> Void) { - guard let clientSettings = token.context.clientSettings, - token.refreshToken != nil - else { - completion(.failure(.missingToken(type: .refreshToken))) - return - } - - refreshQueue.sync { + refreshLock.withLock { + guard let clientSettings = token.context.clientSettings, + token.refreshToken != nil + else { + completion(.failure(.missingToken(type: .refreshToken))) + return + } + guard token.refreshAction == nil else { token.refreshAction?.add(completion) return @@ -180,11 +180,6 @@ public final class OAuth2Client { } } - private(set) lazy var refreshQueue: DispatchQueue = { - DispatchQueue(label: "com.okta.refreshQueue.\(baseURL.host ?? "unknown")", - qos: .userInitiated, - attributes: .concurrent) - }() private func performRefresh(token: Token, clientSettings: [String: String]) { guard let action = token.refreshAction else { return } @@ -524,6 +519,14 @@ public final class OAuth2Client { // MARK: Private properties / methods private let delegates = DelegateCollection() + private let refreshLock = UnfairLock() + private(set) lazy var refreshQueue: DispatchQueue = { + DispatchQueue(label: "com.okta.refreshQueue.\(baseURL.host ?? "unknown")", + qos: .userInitiated, + attributes: .concurrent) + }() + + private let configurationLock = UnfairLock() private lazy var configurationQueue: DispatchQueue = { DispatchQueue(label: "com.okta.configurationQueue.\(baseURL.host ?? "unknown")", qos: .userInitiated, diff --git a/Sources/AuthFoundation/Utilities/UnsafeLock.swift b/Sources/AuthFoundation/Utilities/UnsafeLock.swift new file mode 100644 index 000000000..f6c09b52b --- /dev/null +++ b/Sources/AuthFoundation/Utilities/UnsafeLock.swift @@ -0,0 +1,41 @@ +// +// Copyright (c) 2023-Present, Okta, Inc. and/or its affiliates. All rights reserved. +// The Okta software accompanied by this notice is provided pursuant to the Apache License, Version 2.0 (the "License.") +// +// You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// +// See the License for the specific language governing permissions and limitations under the License. +// + +import Foundation + +// **Note:** It would be preferable to use OSAllocatedUnfairLock for this, but this would mean dropping support for older OS versions. While this approach is safe, OSAllocatedUnfairLock provides more features we might need in the future. +// +// If the minimum supported version of this SDK is to increase in the future, this class should be removed and replaced with OSAllocatedUnfairLock. +final class UnfairLock: NSLocking { + private let _lock: UnsafeMutablePointer = { + let result = UnsafeMutablePointer.allocate(capacity: 1) + result.initialize(to: os_unfair_lock()) + return result + }() + + deinit { + _lock.deinitialize(count: 1) + _lock.deallocate() + } + + func lock() { + os_unfair_lock_lock(_lock) + } + + func tryLock() -> Bool { + os_unfair_lock_trylock(_lock) + } + + func unlock() { + os_unfair_lock_unlock(_lock) + } +} diff --git a/Tests/AuthFoundationTests/APIRetryTests.swift b/Tests/AuthFoundationTests/APIRetryTests.swift index 967441f5d..8a9162021 100644 --- a/Tests/AuthFoundationTests/APIRetryTests.swift +++ b/Tests/AuthFoundationTests/APIRetryTests.swift @@ -76,9 +76,9 @@ class APIRetryTests: XCTestCase { func testCustomRetryCount() throws { client = MockApiClient(configuration: configuration, - session: urlSession, - baseURL: baseUrl, - shouldRetry: .retry(maximumCount: 5)) + session: urlSession, + baseURL: baseUrl, + shouldRetry: .retry(maximumCount: 5)) try performRetryRequest(count: 6) XCTAssertEqual(client.request?.allHTTPHeaderFields?["X-Okta-Retry-Count"], "5") XCTAssertEqual(client.request?.allHTTPHeaderFields?["X-Okta-Retry-For"], requestId) diff --git a/Tests/AuthFoundationTests/CredentialRefreshTests.swift b/Tests/AuthFoundationTests/CredentialRefreshTests.swift index 91d7af9f4..6aa2c3f7e 100644 --- a/Tests/AuthFoundationTests/CredentialRefreshTests.swift +++ b/Tests/AuthFoundationTests/CredentialRefreshTests.swift @@ -344,26 +344,5 @@ final class CredentialRefreshTests: XCTestCase, OAuth2ClientDelegate { try await credential.refreshIfNeeded(graceInterval: 300) } } - - func perform(queueCount: Int = 5, iterationCount: Int = 10, _ block: @escaping () async throws -> Void) rethrows { - let queues: [DispatchQueue] = (0..=5.5.1) + @available(iOS 13.0, tvOS 13.0, macOS 10.15, watchOS 6, *) + func testOpenIDConfigurationAsync() async throws { + urlSession.expect("https://example.com/.well-known/openid-configuration", + data: try data(from: .module, for: "openid-configuration", in: "MockResponses"), + contentType: "application/json") + + try perform { + let config = try await self.client.openIdConfiguration() + XCTAssertEqual(config.authorizationEndpoint.absoluteString, + "https://example.com/oauth2/v1/authorize") + } + } + #endif + func testJWKS() throws { urlSession.expect("https://example.com/.well-known/openid-configuration", data: try data(from: .module, for: "openid-configuration", in: "MockResponses"), diff --git a/Tests/TestCommon/XCTestCase+Extensions.swift b/Tests/TestCommon/XCTestCase+Extensions.swift index a636e6479..b2f3dc649 100644 --- a/Tests/TestCommon/XCTestCase+Extensions.swift +++ b/Tests/TestCommon/XCTestCase+Extensions.swift @@ -75,4 +75,28 @@ public extension XCTestCase { let jsonData = data(for: json) return try decoder.decode(T.self, from: jsonData) } + + #if swift(>=5.5.1) + @available(iOS 13.0, tvOS 13.0, macOS 10.15, watchOS 6, *) + func perform(queueCount: Int = 5, iterationCount: Int = 10, _ block: @escaping () async throws -> Void) rethrows { + let queues: [DispatchQueue] = (0..