Skip to content

Commit

Permalink
add initialValue to CurrentValue subjects
Browse files Browse the repository at this point in the history
  • Loading branch information
BrentMifsud committed Aug 20, 2024
1 parent b93dba9 commit c8558c4
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ public struct AsyncCurrentValueSubject<Element: Sendable>: AsyncSequence, Sendab
private let storage: _Storage

/// A shared AsyncSequence that yields its current value and any value changes to its subscribers
public init() {
storage = _Storage()
/// - Parameter value: the initial value
public init(initialValue value: Element) {
storage = _Storage(initialValue: value)
}

init(storage: _Storage) {
Expand All @@ -42,10 +43,14 @@ public struct AsyncCurrentValueSubject<Element: Sendable>: AsyncSequence, Sendab

extension AsyncCurrentValueSubject {
actor _Storage {
private(set) var currentValue: Element?
private(set) var currentValue: Element
private(set) var finished: Bool = false
private(set) var continuations: [UUID: AsyncStream<Element>.Continuation] = [:]

init(initialValue: Element) {
currentValue = initialValue
}

deinit {
for id in continuations.keys {
continuations[id]?.finish()
Expand Down Expand Up @@ -95,10 +100,7 @@ extension AsyncCurrentValueSubject {
}

continuations[id] = continuation

if let currentValue {
continuation.yield(currentValue)
}
continuation.yield(currentValue)
}

private func removeContinuation(id: UUID) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ public struct AsyncThrowingCurrentValueSubject<Element: Sendable>: AsyncSequence
private let storage: _Storage

/// A shared `AsyncSequence` that yields value changes to its subscribers
public init() {
storage = _Storage()
/// - Parameter value: the initial value
public init(initialValue value: Element) {
storage = _Storage(initialValue: value)
}

init(storage: _Storage) {
Expand Down Expand Up @@ -48,6 +49,10 @@ extension AsyncThrowingCurrentValueSubject {
private(set) var failure: (any Error)?
private(set) var continuations: [UUID: AsyncThrowingStream<Element, any Error>.Continuation] = [:]

init(initialValue: Element) {
currentValue = initialValue
}

deinit {
for id in continuations.keys {
continuations[id]?.finish()
Expand Down
101 changes: 58 additions & 43 deletions Tests/AsyncSubjectsTests/AsyncCurrentValueSubjectTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,43 @@ import Testing

@Suite("Current Value Subject Tests")
struct AsyncCurrentValueSubjectTests {
@Test("Initial value", arguments: [1])
func initialValue(expectedValue: Int) async throws {
let storage = AsyncCurrentValueSubject<Int>._Storage(initialValue: expectedValue)
let subject = AsyncCurrentValueSubject(storage: storage)

await storage.validateInitialState()

let task = Task {
var recievedValues = [Int]()
for await value in subject {
recievedValues.append(value)
}
#expect(recievedValues == [expectedValue])
}

await waitFor { await !storage.continuations.isEmpty }
await subject.finish()
await task.value
await storage.validateEndState()
}

@Test("(Single subscriber) Test emitted Values", arguments: [[1, 2, 3]])
func valuesAreValid(expectedValues: [Int]) async throws {
let storage = AsyncCurrentValueSubject<Int>._Storage()
let storage = AsyncCurrentValueSubject<Int>._Storage(initialValue: 0)
let subject = AsyncCurrentValueSubject<Int>(storage: storage)

await storage.validateInitialState()

await withTaskGroup(of: Void.self) { group in
group.addTask {
var emittedValues = [Int]()

#expect(await storage.currentValue == nil)

for await value in subject {
emittedValues.append(value)
}

#expect(emittedValues == expectedValues)
#expect(emittedValues == [0] + expectedValues)
}

group.addTask {
Expand All @@ -46,24 +67,26 @@ struct AsyncCurrentValueSubjectTests {

#expect(await storage.finished)
}

await storage.validateEndState()
}

@Test("(Multiple subscribers) Test emitted values", arguments: [[1, 2, 3]])
func valuesAreValidWithMultipleSubscribers(expectedValues: [Int]) async throws {
let storage = AsyncCurrentValueSubject<Int>._Storage()
let storage = AsyncCurrentValueSubject<Int>._Storage(initialValue: 0)
let subject = AsyncCurrentValueSubject<Int>(storage: storage)

await storage.validateInitialState()

await withTaskGroup(of: Void.self) { group in
let createSubscriber = { @Sendable () async in
var emittedValues = [Int]()

#expect(await storage.currentValue == nil)

for await value in subject {
emittedValues.append(value)
}

#expect(emittedValues == expectedValues)
#expect(emittedValues == [0] + expectedValues)
}

for _ in 0..<3 {
Expand All @@ -84,10 +107,9 @@ struct AsyncCurrentValueSubjectTests {
}

await group.waitForAll()

#expect(await storage.currentValue == expectedValues.last)
#expect(await storage.finished)
}

await storage.validateEndState()
}

@Test(
Expand All @@ -101,17 +123,19 @@ struct AsyncCurrentValueSubjectTests {
)
func valuesAreValidStaggedSubscribers(arguments: ([Int], [Int])) async throws {
let (values, expectedValues) = arguments
let storage = AsyncCurrentValueSubject<Int>._Storage()
let storage = AsyncCurrentValueSubject<Int>._Storage(initialValue: 0)
let subject = AsyncCurrentValueSubject<Int>(storage: storage)

await storage.validateInitialState()

await withTaskGroup(of: Void.self) { group in
group.addTask {
var recievedValues = [Int]()
for await value in subject {
recievedValues.append(value)
}

#expect(recievedValues == values)
#expect(recievedValues == [0] + values)
}

group.addTask {
Expand Down Expand Up @@ -143,6 +167,8 @@ struct AsyncCurrentValueSubjectTests {

await group.waitForAll()
}

await storage.validateEndState()
}

@Test(
Expand All @@ -157,24 +183,24 @@ struct AsyncCurrentValueSubjectTests {
)
func valuesAreValidAfterThrow(arguments: ([Int], [Int], NSError)) async throws {
let (values, expected, failure) = arguments
let storage = AsyncThrowingCurrentValueSubject<Int>._Storage()
let storage = AsyncThrowingCurrentValueSubject<Int>._Storage(initialValue: 0)
let subject = AsyncThrowingCurrentValueSubject<Int>(storage: storage)

await storage.validateInitialState()

await withTaskGroup(of: Void.self) { group in
group.addTask {
var emittedValues = [Int]()

#expect(await storage.currentValue == nil)

do {
for try await value in subject {
emittedValues.append(value)
}
} catch {
#expect(emittedValues == expected)
#expect((error as NSError) == failure)
return
}

#expect(emittedValues == [0] + expected)
}

group.addTask {
Expand All @@ -192,11 +218,9 @@ struct AsyncCurrentValueSubjectTests {
}

await group.waitForAll()

#expect(await storage.currentValue == expected.last)
#expect(await storage.finished)
#expect((await storage.failure as? NSError) == failure)
}

await storage.validateEndState()
}

@Test(
Expand All @@ -211,27 +235,24 @@ struct AsyncCurrentValueSubjectTests {
)
func valuesAreValidAfterThrowMultipleSubscribers(arguments: ([Int], [Int], NSError)) async throws {
let (values, expected, failure) = arguments
let storage = AsyncThrowingCurrentValueSubject<Int>._Storage()
let storage = AsyncThrowingCurrentValueSubject<Int>._Storage(initialValue: 0)
let subject = AsyncThrowingCurrentValueSubject<Int>(storage: storage)

await storage.validateInitialState()

await withTaskGroup(of: Void.self) { group in
let createSubscriber = { @Sendable () async in
var emittedValues = [Int]()

#expect(await storage.currentValue == nil)
#expect(await !storage.finished)
#expect(await storage.failure == nil)

do {
for try await value in subject {
emittedValues.append(value)
}
} catch {
#expect(emittedValues == expected)
#expect((error as NSError) == failure)
}

#expect(emittedValues == expected)
#expect(emittedValues == [0] + expected)
}

for _ in 0..<3 {
Expand All @@ -255,13 +276,9 @@ struct AsyncCurrentValueSubjectTests {
}

await group.waitForAll()

#expect(
await storage.currentValue == expected.last, "Storage should have the final value before the failure")
#expect(await storage.continuations.isEmpty, "Continuations should be freed up after completion")
#expect(await (storage.failure as? NSError) == failure, "Failure should not be null")
#expect(await storage.finished, "the subject should be finished")
}

await storage.validateEndState()
}

@Test(
Expand All @@ -276,9 +293,11 @@ struct AsyncCurrentValueSubjectTests {
)
func failedSubscriberDoesNotreceiveCurrentValue(arguments: ([Int], [Int], NSError)) async throws {
let (values, expectedValues, failure) = arguments
let storage = AsyncThrowingCurrentValueSubject<Int>._Storage()
let storage = AsyncThrowingCurrentValueSubject<Int>._Storage(initialValue: 0)
let subject = AsyncThrowingCurrentValueSubject<Int>(storage: storage)

await storage.validateInitialState()

await withTaskGroup(of: Void.self) { group in
group.addTask {
var recievedValues = [Int]()
Expand All @@ -291,7 +310,7 @@ struct AsyncCurrentValueSubjectTests {
#expect((error as NSError) == failure)
}

#expect(recievedValues == expectedValues)
#expect(recievedValues == [0] + expectedValues)
}

group.addTask {
Expand Down Expand Up @@ -324,23 +343,19 @@ struct AsyncCurrentValueSubjectTests {
}

await group.waitForAll()

#expect(await storage.continuations.isEmpty, "Continuations should be freed up after completion")
#expect(await (storage.failure as? NSError) == failure, "Failure should not be null")
#expect(await storage.finished, "the subject should be finished")
}

await storage.validateEndState()
}
}

extension AsyncCurrentValueSubject._Storage {
func validateInitialState() {
#expect(currentValue == nil)
#expect(!finished, "Initial state should not be finished")
#expect(continuations.isEmpty, "Initial state should not have continuations")
}

func validateEndState() {
#expect(currentValue != nil)
#expect(finished, "Final state should be finished")
#expect(continuations.isEmpty, "Final state should not have any continuations")
}
Expand Down
Loading

0 comments on commit c8558c4

Please sign in to comment.