Skip to content

Commit

Permalink
Use a proper lock to gate access to concurrent operations (like refre…
Browse files Browse the repository at this point in the history
…sh or openIdConfiguration) instead of relying on a queue
  • Loading branch information
mikenachbaur-okta committed Feb 16, 2024
1 parent bb44b85 commit 7e6d4c0
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 38 deletions.
31 changes: 17 additions & 14 deletions Sources/AuthFoundation/OAuth2/OAuth2Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public final class OAuth2Client {
/// If this value has recently been retrieved, the cached result is returned.
/// - Parameter completion: Completion block invoked with the result.
public func openIdConfiguration(completion: @escaping (Result<OpenIdConfiguration, OAuth2Error>) -> Void) {
configurationQueue.sync {
configurationLock.withLock {
if let openIdConfiguration = openIdConfiguration {
completion(.success(openIdConfiguration))
} else {
Expand Down Expand Up @@ -161,14 +161,14 @@ public final class OAuth2Client {
/// - token: Token to refresh.
/// - completion: Completion bock invoked with the result.
public func refresh(_ token: Token, completion: @escaping (Result<Token, OAuth2Error>) -> Void) {
guard let clientSettings = token.context.clientSettings,
token.refreshToken != nil
else {
completion(.failure(.missingToken(type: .refreshToken)))
return
}

refreshQueue.sync {
refreshLock.withLock {
guard let clientSettings = token.context.clientSettings,
token.refreshToken != nil
else {
completion(.failure(.missingToken(type: .refreshToken)))
return
}
guard token.refreshAction == nil else {
token.refreshAction?.add(completion)
return
Expand All @@ -180,11 +180,6 @@ public final class OAuth2Client {
}
}

private(set) lazy var refreshQueue: DispatchQueue = {
DispatchQueue(label: "com.okta.refreshQueue.\(baseURL.host ?? "unknown")",
qos: .userInitiated,
attributes: .concurrent)
}()
private func performRefresh(token: Token, clientSettings: [String: String]) {
guard let action = token.refreshAction else { return }

Expand Down Expand Up @@ -524,6 +519,14 @@ public final class OAuth2Client {
// MARK: Private properties / methods
private let delegates = DelegateCollection<OAuth2ClientDelegate>()

private let refreshLock = UnfairLock()
private(set) lazy var refreshQueue: DispatchQueue = {
DispatchQueue(label: "com.okta.refreshQueue.\(baseURL.host ?? "unknown")",
qos: .userInitiated,
attributes: .concurrent)
}()

private let configurationLock = UnfairLock()
private lazy var configurationQueue: DispatchQueue = {
DispatchQueue(label: "com.okta.configurationQueue.\(baseURL.host ?? "unknown")",
qos: .userInitiated,
Expand Down
41 changes: 41 additions & 0 deletions Sources/AuthFoundation/Utilities/UnsafeLock.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//
// Copyright (c) 2023-Present, Okta, Inc. and/or its affiliates. All rights reserved.
// The Okta software accompanied by this notice is provided pursuant to the Apache License, Version 2.0 (the "License.")
//
// You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
// See the License for the specific language governing permissions and limitations under the License.
//

import Foundation

// **Note:** It would be preferable to use OSAllocatedUnfairLock for this, but this would mean dropping support for older OS versions. While this approach is safe, OSAllocatedUnfairLock provides more features we might need in the future.
//
// If the minimum supported version of this SDK is to increase in the future, this class should be removed and replaced with OSAllocatedUnfairLock.
final class UnfairLock: NSLocking {
private let _lock: UnsafeMutablePointer<os_unfair_lock> = {
let result = UnsafeMutablePointer<os_unfair_lock>.allocate(capacity: 1)
result.initialize(to: os_unfair_lock())
return result
}()

deinit {
_lock.deinitialize(count: 1)
_lock.deallocate()
}

func lock() {
os_unfair_lock_lock(_lock)
}

func tryLock() -> Bool {
os_unfair_lock_trylock(_lock)
}

func unlock() {
os_unfair_lock_unlock(_lock)
}
}
6 changes: 3 additions & 3 deletions Tests/AuthFoundationTests/APIRetryTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ class APIRetryTests: XCTestCase {

func testCustomRetryCount() throws {
client = MockApiClient(configuration: configuration,
session: urlSession,
baseURL: baseUrl,
shouldRetry: .retry(maximumCount: 5))
session: urlSession,
baseURL: baseUrl,
shouldRetry: .retry(maximumCount: 5))
try performRetryRequest(count: 6)
XCTAssertEqual(client.request?.allHTTPHeaderFields?["X-Okta-Retry-Count"], "5")
XCTAssertEqual(client.request?.allHTTPHeaderFields?["X-Okta-Retry-For"], requestId)
Expand Down
21 changes: 0 additions & 21 deletions Tests/AuthFoundationTests/CredentialRefreshTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -344,26 +344,5 @@ final class CredentialRefreshTests: XCTestCase, OAuth2ClientDelegate {
try await credential.refreshIfNeeded(graceInterval: 300)
}
}

func perform(queueCount: Int = 5, iterationCount: Int = 10, _ block: @escaping () async throws -> Void) rethrows {
let queues: [DispatchQueue] = (0..<queueCount).map { queueNumber in
DispatchQueue(label: "Async queue \(queueNumber)")
}

let group = DispatchGroup()
for queue in queues {
for _ in 0..<iterationCount {
queue.async {
group.enter()
Task {
try await block()
group.leave()
}
}
}
}

_ = group.wait(timeout: .short)
}
#endif
}
15 changes: 15 additions & 0 deletions Tests/AuthFoundationTests/OAuth2ClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,21 @@ final class OAuth2ClientTests: XCTestCase {
"https://example.com/oauth2/v1/authorize")
}

#if swift(>=5.5.1)
@available(iOS 13.0, tvOS 13.0, macOS 10.15, watchOS 6, *)
func testOpenIDConfigurationAsync() async throws {
urlSession.expect("https://example.com/.well-known/openid-configuration",
data: try data(from: .module, for: "openid-configuration", in: "MockResponses"),
contentType: "application/json")

try perform {
let config = try await self.client.openIdConfiguration()
XCTAssertEqual(config.authorizationEndpoint.absoluteString,
"https://example.com/oauth2/v1/authorize")
}
}
#endif

func testJWKS() throws {
urlSession.expect("https://example.com/.well-known/openid-configuration",
data: try data(from: .module, for: "openid-configuration", in: "MockResponses"),
Expand Down
24 changes: 24 additions & 0 deletions Tests/TestCommon/XCTestCase+Extensions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,28 @@ public extension XCTestCase {
let jsonData = data(for: json)
return try decoder.decode(T.self, from: jsonData)
}

#if swift(>=5.5.1)
@available(iOS 13.0, tvOS 13.0, macOS 10.15, watchOS 6, *)
func perform(queueCount: Int = 5, iterationCount: Int = 10, _ block: @escaping () async throws -> Void) rethrows {
let queues: [DispatchQueue] = (0..<queueCount).map { queueNumber in
DispatchQueue(label: "Async queue \(queueNumber)")
}

let group = DispatchGroup()
for queue in queues {
for _ in 0..<iterationCount {
queue.async {
group.enter()
Task {
try await block()
group.leave()
}
}
}
}

_ = group.wait(timeout: .short)
}
#endif
}

0 comments on commit 7e6d4c0

Please sign in to comment.