Skip to content

Commit

Permalink
More flow operations: flatMap, flatten, flattenPar, groupBy (#247)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw authored Nov 28, 2024
1 parent 3508bdd commit 055e19d
Show file tree
Hide file tree
Showing 21 changed files with 994 additions and 170 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ Flow.iterate(0)(_ + 1) // natural numbers
.map(_ + 1)
.intersperse(5)
// compute the running total
.mapStateful(() => 0) { (state, value) =>
.mapStateful(0) { (state, value) =>
val newState = state + value
(newState, newState)
}
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/ox/channels/BufferCapacity.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package ox.channels

/** Used to determine the capacity of buffers, when new channels are created by channel or flow-transforming operations, such as
* [[Source.map]], [[Flow.async]], [[Flow.runToChannel]]. If not in scope, the default of 16 is used.
* [[Source.map]], [[Flow.buffer]], [[Flow.runToChannel]]. If not in scope, the default of 16 is used.
*/
opaque type BufferCapacity = Int

Expand Down
8 changes: 4 additions & 4 deletions core/src/main/scala/ox/channels/select.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def selectOrClosed(
* The result returned by the selected clause, wrapped with [[SelectResult]], or a [[ChannelClosed]], when any of the channels is closed
* (done or in error).
*/
def selectOrClosed[T](clauses: List[SelectClause[T]]): SelectResult[T] | ChannelClosed =
def selectOrClosed[T](clauses: Seq[SelectClause[T]]): SelectResult[T] | ChannelClosed =
ChannelClosed.fromJoxOrT(JSelect.selectOrClosed(clauses.map(_.delegate)*))

//
Expand Down Expand Up @@ -105,7 +105,7 @@ def select(
* @throws ChannelClosedException
* When any of the channels is closed (done or in error).
*/
def select[T](clauses: List[SelectClause[T]]): SelectResult[T] = selectOrClosed(clauses).orThrow
def select[T](clauses: Seq[SelectClause[T]]): SelectResult[T] = selectOrClosed(clauses).orThrow

//

Expand Down Expand Up @@ -164,7 +164,7 @@ def selectOrClosed[T1, T2, T3, T4, T5](
* @return
* The value received from the selected source, or a [[ChannelClosed]], when any of the channels is closed (done or in error).
*/
def selectOrClosed[T](sources: List[Source[T]])(using DummyImplicit): T | ChannelClosed =
def selectOrClosed[T](sources: Seq[Source[T]])(using DummyImplicit): T | ChannelClosed =
selectOrClosed(sources.map(_.receiveClause: SelectClause[T])) match
case r: Source[T]#Received => r.value
case c: ChannelClosed => c
Expand Down Expand Up @@ -225,5 +225,5 @@ def select[T1, T2, T3, T4, T5](
* @throws ChannelClosedException
* When any of the channels is closed (done or in error).
*/
def select[T](sources: List[Source[T]])(using DummyImplicit): T | ChannelClosed =
def select[T](sources: Seq[Source[T]])(using DummyImplicit): T | ChannelClosed =
selectOrClosed(sources).orThrow
2 changes: 1 addition & 1 deletion core/src/main/scala/ox/flow/Flow.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import scala.annotation.nowarn
*
* Flows can be created using the [[Flow.usingSink]], [[Flow.fromValues]] and other `Flow.from*` methods, [[Flow.tick]] etc.
*
* Transformation stages can be added using the available combinators, such as [[Flow.map]], [[Flow.async]], [[Flow.grouped]], etc. Each
* Transformation stages can be added using the available combinators, such as [[Flow.map]], [[Flow.buffer]], [[Flow.grouped]], etc. Each
* such method returns a new immutable `Flow` instance.
*
* Running a flow is possible using one of the `run*` methods, such as [[Flow.runToList]], [[Flow.runToChannel]] or [[Flow.runFold]].
Expand Down
130 changes: 110 additions & 20 deletions core/src/main/scala/ox/flow/FlowOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@ import ox.CancellableFork
import ox.Fork
import ox.Ox
import ox.OxUnsupervised
import ox.channels.BufferCapacity
import ox.channels.Channel
import ox.channels.ChannelClosed
import ox.channels.Default
import ox.channels.Sink
import ox.channels.Source
import ox.channels.BufferCapacity
import ox.channels.forkPropagate
import ox.channels.selectOrClosed
import ox.discard
import ox.flow.internal.groupByImpl
import ox.forkCancellable
import ox.forkUnsupervised
import ox.forkUser
import ox.repeatWhile
import ox.sleep
import ox.supervised
Expand All @@ -24,7 +26,6 @@ import ox.unsupervised
import java.util.concurrent.Semaphore
import scala.concurrent.duration.DurationLong
import scala.concurrent.duration.FiniteDuration
import ox.forkUser

class FlowOps[+T]:
outer: Flow[T] =>
Expand All @@ -36,7 +37,7 @@ class FlowOps[+T]:
*
* Any exceptions are propagated by the returned flow.
*/
def async()(using BufferCapacity): Flow[T] = Flow.usingEmitInline: emit =>
def buffer()(using BufferCapacity): Flow[T] = Flow.usingEmitInline: emit =>
val ch = BufferCapacity.newChannel[T]
unsupervised:
runLastToChannelAsync(ch)
Expand Down Expand Up @@ -88,6 +89,20 @@ class FlowOps[+T]:
f(t); t
)

/** Applies the given mapping function `f` to each element emitted by this flow, obtaining a nested flow to run. The elements emitted by
* the nested flow are then emitted by the returned flow.
*
* The nested flows are run in sequence, that is, the next nested flow is started only after the previous one completes.
*
* @param f
* The mapping function.
*/
def flatMap[U](f: T => Flow[U]): Flow[U] = Flow.usingEmitInline: emit =>
last.run(
FlowEmit.fromInline: t =>
f(t).runToEmit(emit)
)

/** Intersperses elements emitted by this flow with `inject` elements. The `inject` element is emitted between each pair of elements. */
def intersperse[U >: T](inject: U): Flow[U] = intersperse(None, inject, None)

Expand Down Expand Up @@ -288,30 +303,72 @@ class FlowOps[+T]:
case e: ChannelClosed.Error => throw e.toThrowable
case r: U @unchecked => emit(r); true

/** Pipes the elements of child flows into the output source. If the parent source or any of the child sources emit an error, the pulling
* stops and the output source emits the error.
/** Given that this flow emits other flows, flattens the nested flows into a single flow. The resulting flow emits elements from the
* nested flows in the order they are emitted.
*
* The nested flows are run in sequence, that is, the next nested flow is started only after the previous one completes.
*/
def flatten[U](using T <:< Flow[U]): Flow[U] = this.flatMap(identity)

/** Pipes the elements of child flows into the returned flow.
*
* Runs all flows concurrently in the background. The size of the buffers is determined by the [[BufferCapacity]] that is in scope.
* If the this flow or any of the child flows emit an error, the pulling stops and the output flow propagates the error.
*
* Up to [[parallelism]] child flows are run concurrently in the background. When the limit is reached, until a child flow completes, no
* more child flows are run.
*
* The size of the buffers for the elements emitted by the child flows is determined by the [[BufferCapacity]] that is in scope.
*
* @param parallelism
* An upper bound on the number of child flows that run in parallel.
*/
def flatten[U](using T <:< Flow[U])(using BufferCapacity): Flow[U] = Flow.usingEmitInline: emit =>
def flattenPar[U](parallelism: Int)(using T <:< Flow[U])(using BufferCapacity): Flow[U] = Flow.usingEmitInline: emit =>
case class Nested(child: Flow[U])
case object ChildDone

unsupervised:
val childStream = outer.map(Nested(_)).runToChannel()
var pool = List[Source[Nested] | Source[U]](childStream)
val childOutputChannel = BufferCapacity.newChannel[U]
val childDoneChannel = Channel.unlimited[ChildDone.type]

// When an error occurs in the parent, propagating it also to `childOutputChannel`, from which we always
// `select` in the main loop. That way, even if max parallelism is reached, errors in the parent will
// be discovered without delay.
val parentChannel = outer.map(Nested(_)).onError(childOutputChannel.error(_).discard).runToChannel()

var runningChannelCount = 1 // parent is running
var parentDone = false

while runningChannelCount > 0 do
assert(runningChannelCount <= parallelism + 1)

val pool: List[Source[Nested] | Source[U] | Source[ChildDone.type]] =
// +1, because of the parent channel.
if runningChannelCount == parallelism + 1 || parentDone then List(childOutputChannel, childDoneChannel)
else List(childOutputChannel, childDoneChannel, parentChannel)

repeatWhile:
selectOrClosed(pool) match
// Only `parentChannel` might be done, child completion is signalled via `childDoneChannel`.
case ChannelClosed.Done =>
// TODO: optimization idea: find a way to remove the specific channel that signalled to be Done
pool = pool.filterNot(_.isClosedForReceiveDetail.contains(ChannelClosed.Done))
if pool.isEmpty then false
else true
parentDone = isSourceDone(parentChannel)
assert(parentDone)

runningChannelCount -= 1

case e: ChannelClosed.Error => throw e.toThrowable

case ChildDone => runningChannelCount -= 1

case Nested(t) =>
pool = t.runToChannel() :: pool
true
case r: U @unchecked => emit(r); true
forkUnsupervised:
t.onDone(childDoneChannel.send(ChildDone)).runPipeToSink(childOutputChannel, propagateDone = false)
.discard

runningChannelCount += 1

case u: U @unchecked => emit(u)
end match
end while
end flattenPar

/** Concatenates this flow with the `other` flow. The resulting flow will emit elements from this flow first, and then from the `other`
* flow.
Expand Down Expand Up @@ -420,7 +477,7 @@ class FlowOps[+T]:
* A function that transforms the final state into an optional element emitted by the returned flow. By default the final state is
* ignored.
*/
def mapStateful[S, U](initializeState: () => S)(f: (S, T) => (S, U), onComplete: S => Option[U] = (_: S) => None): Flow[U] =
def mapStateful[S, U](initializeState: => S)(f: (S, T) => (S, U), onComplete: S => Option[U] = (_: S) => None): Flow[U] =
def resultToSome(s: S, t: T) =
val (newState, result) = f(s, t)
(newState, Some(result))
Expand Down Expand Up @@ -448,9 +505,9 @@ class FlowOps[+T]:
* ignored.
*/
def mapStatefulConcat[S, U](
initializeState: () => S
initializeState: => S
)(f: (S, T) => (S, IterableOnce[U]), onComplete: S => Option[U] = (_: S) => None): Flow[U] = Flow.usingEmitInline: emit =>
var state = initializeState()
var state = initializeState
last.run(
FlowEmit.fromInline: t =>
val (nextState, result) = f(state, t)
Expand Down Expand Up @@ -702,6 +759,37 @@ class FlowOps[+T]:
}.tapException(other.errorOrClosed(_).discard)
end alsoToTap

/** Groups elements emitted by this flow into child flows. Elements for which [[predicate]] returns the same value (of type `V`) end up in
* the same child flow. [[childFlowTransform]] is applied to each created child flow, and the resulting flow is run in the background.
* Finally, the child flows are merged back, that is any elements that they emit are emitted by the returned flow.
*
* Up to [[parallelism]] child flows are run concurrently in the background. When the limit is reached, the child flow which didn't
* receive a new element the longest is completed as done.
*
* Child flows for `V` values might be created multiple times (if, after completing a child flow because of parallelism limit, new
* elements arrive, mapped to a given `V` value). However, it is guaranteed that for a given `V` value, there will be at most one child
* flow running at any time.
*
* Child flows should only complete as done when the flow of received `T` elements completes. Otherwise, the entire stream will fail with
* an error.
*
* Errors that occur in this flow, or in any child flows, become errors of the returned flow (exceptions are wrapped in
* [[ChannelClosedException]]).
*
* The size of the buffers for the elements emitted by this flow (which is also run in the background) and the child flows are determined
* by the [[BufferCapacity]] that is in scope.
*
* @param parallelism
* An upper bound on the number of child flows that run in parallel at any time.
* @param predicate
* Function used to determine the group for an element of type `T`. Each group is represented by a value of type `V`.
* @param childFlowTransform
* The function that is used to create a child flow, which is later in the background. The arguments are the group value, for which the
* flow is created, and a flow of `T` elements in that group (each such element has the same group value `V` returned by `predicated`).
*/
def groupBy[V, U](parallelism: Int, predicate: T => V)(childFlowTransform: V => Flow[T] => Flow[U])(using BufferCapacity): Flow[U] =
groupByImpl(outer, parallelism, predicate)(childFlowTransform)

/** Discard all elements emitted by this flow. The returned flow completes only when this flow completes (successfully or with an error).
*/
def drain(): Flow[Nothing] = Flow.usingEmitInline: emit =>
Expand Down Expand Up @@ -729,3 +817,5 @@ class FlowOps[+T]:
ch.done()
}.discard
end FlowOps

private[flow] inline def isSourceDone(ch: Source[?]) = ch.isClosedForReceiveDetail.contains(ChannelClosed.Done)
13 changes: 10 additions & 3 deletions core/src/main/scala/ox/flow/FlowRunOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import ox.channels.BufferCapacity
import ox.discard

import scala.collection.mutable.ListBuffer
import scala.util.control.NonFatal

trait FlowRunOps[+T]:
this: Flow[T] =>
Expand Down Expand Up @@ -44,11 +45,17 @@ trait FlowRunOps[+T]:

/** Passes each element emitted by this flow to the given sink. Blocks until the flow completes.
*
* Errors are always propagated. Successful flow completion is propagated when `propagateDone` is set to `true`.
* Errors are always propagated to the provided sink. Successful flow completion is propagated when `propagateDone` is set to `true`.
*
* Fatal errors are rethrown.
*/
def runPipeToSink(sink: Sink[T], propagateDone: Boolean): Unit =
last.run(FlowEmit.fromInline(t => sink.send(t)))
if propagateDone then sink.doneOrClosed().discard
try
last.run(FlowEmit.fromInline(t => sink.send(t)))
if propagateDone then sink.doneOrClosed().discard
catch
case NonFatal(e) => sink.error(e)
case t => sink.error(t); throw t

/** Ignores all elements emitted by the flow. Blocks until the flow completes. */
def runDrain(): Unit = runForeach(_ => ())
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/ox/flow/FlowTextOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ trait FlowTextOps[+T]:
def lines(charset: Charset)(using T <:< Chunk[Byte]): Flow[String] =
// buffer == null is a special state for handling empty chunks in onComplete, in order to tell them apart from empty lines
outer
.mapStatefulConcat(() => null: Chunk[Byte])(
.mapStatefulConcat(null: Chunk[Byte])(
{ case (buffer, nextChunk) =>
@tailrec
def splitChunksAtNewLine(buf: Chunk[Byte], chunk: Chunk[Byte], acc: Vector[Chunk[Byte]]): (Chunk[Byte], Vector[Chunk[Byte]]) =
Expand Down
Loading

0 comments on commit 055e19d

Please sign in to comment.