diff --git a/core/shared/src/main/scala/fs2/Stream.scala b/core/shared/src/main/scala/fs2/Stream.scala index 9fd98ef5c4..49da1f592b 100644 --- a/core/shared/src/main/scala/fs2/Stream.scala +++ b/core/shared/src/main/scala/fs2/Stream.scala @@ -29,7 +29,7 @@ import cats.data.Ior import cats.effect.Concurrent import cats.effect.kernel._ import cats.effect.kernel.implicits._ -import cats.effect.std.{Console, Queue, QueueSink, QueueSource, Semaphore} +import cats.effect.std.{Console, CountDownLatch, Queue, QueueSink, QueueSource, Semaphore} import cats.effect.Resource.ExitCase import cats.syntax.all._ import fs2.compat._ @@ -231,37 +231,33 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, * 2. chunks from each pipe come out of the resulting stream in the same * order as they came out of the pipe, and without skipping any chunk. */ - def broadcastThrough[F2[x] >: F[x]: Concurrent, O2]( - pipes: Pipe[F2, O, O2]* - ): Stream[F2, O2] = { + def broadcastThrough[F2[x] >: F[x]: Concurrent, O2](pipes: Pipe[F2, O, O2]*): Stream[F2, O2] = { assert(pipes.nonEmpty, s"pipes should not be empty") - Stream - .eval { - ( - cats.effect.std.CountDownLatch[F2](pipes.length), - fs2.concurrent.Topic[F2, Chunk[O]] - ).tupled - } - .flatMap { case (latch, topic) => - def produce = chunks.through(topic.publish) - - def consume(pipe: Pipe[F2, O, O2]): Pipe[F2, Chunk[O], O2] = - _.unchunks.through(pipe) - - Stream(pipes: _*) - .map { pipe => - Stream - .resource(topic.subscribeAwait(1)) - .flatMap { sub => - // crucial that awaiting on the latch is not passed to - // the pipe, so that the pipe cannot interrupt it and alter - // the latch count - Stream.exec(latch.release >> latch.await) ++ sub.through(consume(pipe)) - } + Stream.force { + for { + // topic: contains the chunk that the pipes are processing at one point. + // until and unless all pipes are finished with it, won't move to next one + topic <- Topic[F2, Chunk[O]] + // Coordination: neither the producer nor any consumer starts + // until and unless all consumers are subscribed to topic. + allReady <- CountDownLatch[F2](pipes.length) + } yield { + val checkIn = allReady.release >> allReady.await + + def dump(pipe: Pipe[F2, O, O2]): Stream[F2, O2] = + Stream.resource(topic.subscribeAwait(1)).flatMap { sub => + // Wait until all pipes are ready before consuming. + // Crucial: checkin is not passed to the pipe, + // so pipe cannot interrupt it and alter the latch count + Stream.exec(checkIn) ++ pipe(sub.unchunks) } - .parJoinUnbounded - .concurrently(Stream.eval(latch.await) ++ produce) + + val dumpAll: Stream[F2, O2] = Stream(pipes: _*).map(dump).parJoinUnbounded + // Wait until all pipes are checked in before pulling + val pump = Stream.exec(allReady.await) ++ topic.publish(chunks) + dumpAll.concurrently(pump) } + } } /** Behaves like the identity function, but requests `n` elements at a time from the input.