Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix concurrent accesses to the async refreshIfNeeded function #173

Merged
merged 5 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions Sources/AuthFoundation/OAuth2/OAuth2Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,11 @@ public final class OAuth2Client {
/// If this value has recently been retrieved, the cached result is returned.
/// - Parameter completion: Completion block invoked with the result.
public func openIdConfiguration(completion: @escaping (Result<OpenIdConfiguration, OAuth2Error>) -> Void) {
configurationQueue.sync {
configurationLock.withLock {
if let openIdConfiguration = openIdConfiguration {
completion(.success(openIdConfiguration))
configurationQueue.async {
completion(.success(openIdConfiguration))
}
} else {
guard openIdConfigurationAction == nil else {
openIdConfigurationAction?.add(completion)
Expand Down Expand Up @@ -161,14 +163,14 @@ public final class OAuth2Client {
/// - token: Token to refresh.
/// - completion: Completion bock invoked with the result.
public func refresh(_ token: Token, completion: @escaping (Result<Token, OAuth2Error>) -> Void) {
guard let clientSettings = token.context.clientSettings,
token.refreshToken != nil
else {
completion(.failure(.missingToken(type: .refreshToken)))
return
}

refreshQueue.sync {
refreshLock.withLock {
guard let clientSettings = token.context.clientSettings,
token.refreshToken != nil
else {
completion(.failure(.missingToken(type: .refreshToken)))
return
}
guard token.refreshAction == nil else {
token.refreshAction?.add(completion)
return
Expand All @@ -180,11 +182,6 @@ public final class OAuth2Client {
}
}

private(set) lazy var refreshQueue: DispatchQueue = {
DispatchQueue(label: "com.okta.refreshQueue.\(baseURL.host ?? "unknown")",
qos: .userInitiated,
attributes: .concurrent)
}()
private func performRefresh(token: Token, clientSettings: [String: String]) {
guard let action = token.refreshAction else { return }

Expand Down Expand Up @@ -524,6 +521,14 @@ public final class OAuth2Client {
// MARK: Private properties / methods
private let delegates = DelegateCollection<OAuth2ClientDelegate>()

private let refreshLock = UnfairLock()
private(set) lazy var refreshQueue: DispatchQueue = {
DispatchQueue(label: "com.okta.refreshQueue.\(baseURL.host ?? "unknown")",
qos: .userInitiated,
attributes: .concurrent)
}()

private let configurationLock = UnfairLock()
private lazy var configurationQueue: DispatchQueue = {
DispatchQueue(label: "com.okta.configurationQueue.\(baseURL.host ?? "unknown")",
qos: .userInitiated,
Expand Down
41 changes: 41 additions & 0 deletions Sources/AuthFoundation/Utilities/UnsafeLock.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//
// Copyright (c) 2023-Present, Okta, Inc. and/or its affiliates. All rights reserved.
// The Okta software accompanied by this notice is provided pursuant to the Apache License, Version 2.0 (the "License.")
//
// You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
// See the License for the specific language governing permissions and limitations under the License.
//

import Foundation

// **Note:** It would be preferable to use OSAllocatedUnfairLock for this, but this would mean dropping support for older OS versions. While this approach is safe, OSAllocatedUnfairLock provides more features we might need in the future.
//
// If the minimum supported version of this SDK is to increase in the future, this class should be removed and replaced with OSAllocatedUnfairLock.
final class UnfairLock: NSLocking {
private let _lock: UnsafeMutablePointer<os_unfair_lock> = {
let result = UnsafeMutablePointer<os_unfair_lock>.allocate(capacity: 1)
result.initialize(to: os_unfair_lock())
return result
}()

deinit {
_lock.deinitialize(count: 1)
_lock.deallocate()
}

func lock() {
os_unfair_lock_lock(_lock)
}

func tryLock() -> Bool {
os_unfair_lock_trylock(_lock)
}

func unlock() {
os_unfair_lock_unlock(_lock)
}
}
6 changes: 3 additions & 3 deletions Tests/AuthFoundationTests/APIRetryTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ class APIRetryTests: XCTestCase {

func testCustomRetryCount() throws {
client = MockApiClient(configuration: configuration,
session: urlSession,
baseURL: baseUrl,
shouldRetry: .retry(maximumCount: 5))
session: urlSession,
baseURL: baseUrl,
shouldRetry: .retry(maximumCount: 5))
try performRetryRequest(count: 6)
XCTAssertEqual(client.request?.allHTTPHeaderFields?["X-Okta-Retry-Count"], "5")
XCTAssertEqual(client.request?.allHTTPHeaderFields?["X-Okta-Retry-For"], requestId)
Expand Down
18 changes: 13 additions & 5 deletions Tests/AuthFoundationTests/CredentialRefreshTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -315,26 +315,34 @@ final class CredentialRefreshTests: XCTestCase, OAuth2ClientDelegate {
@available(iOS 13.0, tvOS 13.0, macOS 10.15, watchOS 6, *)
func testRefreshAsync() async throws {
let credential = try credential(for: Token.simpleMockToken)
try await credential.refresh()
try perform {
try await credential.refresh()
}
}

@available(iOS 13.0, tvOS 13.0, macOS 10.15, watchOS 6, *)
func testRefreshIfNeededExpiredAsync() async throws {
let credential = try credential(for: Token.mockToken(issuedOffset: 6000))
try await credential.refreshIfNeeded(graceInterval: 300)
try perform {
try await credential.refreshIfNeeded(graceInterval: 300)
}
}

@available(iOS 13.0, tvOS 13.0, macOS 10.15, watchOS 6, *)
func testRefreshIfNeededWithinGraceIntervalAsync() async throws {
let credential = try credential(for: Token.mockToken(issuedOffset: 0),
expectAPICalls: .none)
try await credential.refreshIfNeeded(graceInterval: 300)
try perform {
try await credential.refreshIfNeeded(graceInterval: 300)
}
}

@available(iOS 13.0, tvOS 13.0, macOS 10.15, watchOS 6, *)
func testRefreshIfNeededOutsideGraceIntervalAsync() async throws {
let credential = try credential(for: Token.mockToken(issuedOffset: 3500))
try await credential.refreshIfNeeded(graceInterval: 300)
let credential = try credential(for: Token.mockToken(issuedOffset: 3500))
try perform {
try await credential.refreshIfNeeded(graceInterval: 300)
}
}
#endif
}
16 changes: 16 additions & 0 deletions Tests/AuthFoundationTests/OAuth2ClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,22 @@ final class OAuth2ClientTests: XCTestCase {
"https://example.com/oauth2/v1/authorize")
}

#if swift(>=5.5.1)
@available(iOS 13.0, tvOS 13.0, macOS 10.15, watchOS 6, *)
func testOpenIDConfigurationAsync() async throws {
urlSession.expect("https://example.com/.well-known/openid-configuration",
data: try data(from: .module, for: "openid-configuration", in: "MockResponses"),
contentType: "application/json")

let client = try XCTUnwrap(self.client)
try perform {
let config = try await client.openIdConfiguration()
XCTAssertEqual(config.authorizationEndpoint.absoluteString,
"https://example.com/oauth2/v1/authorize")
}
}
#endif

func testJWKS() throws {
urlSession.expect("https://example.com/.well-known/openid-configuration",
data: try data(from: .module, for: "openid-configuration", in: "MockResponses"),
Expand Down
27 changes: 27 additions & 0 deletions Tests/TestCommon/TimeInterval+Extensions.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//
// Copyright (c) 2023-Present, Okta, Inc. and/or its affiliates. All rights reserved.
// The Okta software accompanied by this notice is provided pursuant to the Apache License, Version 2.0 (the "License.")
//
// You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
// See the License for the specific language governing permissions and limitations under the License.
//

import Foundation

public extension TimeInterval {
static let standard: Self = 3
static let short: Self = 1
static let long: Self = 5
static let veryLong: Self = 10
}

public extension DispatchTime {
static var standard: Self { .now() + .standard }
static var short: Self { .now() + .short }
static var long: Self { .now() + .long }
static var veryLong: Self { .now() + .veryLong }
}
21 changes: 13 additions & 8 deletions Tests/TestCommon/URLSessionMock.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import FoundationNetworking

class URLSessionMock: URLSessionProtocol {
var configuration: URLSessionConfiguration = .ephemeral
let queue = DispatchQueue(label: "URLSessionMock")

struct Call {
let url: String
Expand All @@ -38,7 +39,9 @@ class URLSessionMock: URLSessionProtocol {

private(set) var expectedCalls: [Call] = []
func expect(call: Call) {
expectedCalls.append(call)
queue.sync {
expectedCalls.append(call)
}
}

func expect(_ url: String,
Expand All @@ -61,14 +64,16 @@ class URLSessionMock: URLSessionProtocol {
}

func call(for url: String) -> Call? {
guard let index = expectedCalls.firstIndex(where: { call in
call.url == url
}) else {
XCTFail("Mock URL \(url) not found")
return nil
queue.sync {
guard let index = expectedCalls.firstIndex(where: { call in
call.url == url
}) else {
XCTFail("Mock URL \(url) not found")
return nil
}

return expectedCalls.remove(at: index)
}

return expectedCalls.remove(at: index)
}

func dataTaskWithRequest(_ request: URLRequest, completionHandler: @escaping (Data?, URLResponse?, Error?) -> Void) -> URLSessionDataTaskProtocol {
Expand Down
24 changes: 24 additions & 0 deletions Tests/TestCommon/XCTestCase+Extensions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,28 @@ public extension XCTestCase {
let jsonData = data(for: json)
return try decoder.decode(T.self, from: jsonData)
}

#if swift(>=5.5.1)
@available(iOS 13.0, tvOS 13.0, macOS 10.15, watchOS 6, *)
func perform(queueCount: Int = 5, iterationCount: Int = 10, _ block: @escaping () async throws -> Void) rethrows {
let queues: [DispatchQueue] = (0..<queueCount).map { queueNumber in
DispatchQueue(label: "Async queue \(queueNumber)")
}

let group = DispatchGroup()
for queue in queues {
for _ in 0..<iterationCount {
queue.async {
group.enter()
Task {
try await block()
group.leave()
}
}
}
}

_ = group.wait(timeout: .short)
}
#endif
}
Loading