Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix uncancellable reactive-streams StreamSubscriber #3446

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ import cats.effect.std.Random

import java.nio.ByteBuffer
import java.util.concurrent.Flow.{Publisher, Subscriber, Subscription}
import scala.concurrent.duration._
import java.util.concurrent.atomic.AtomicBoolean
import scala.concurrent.duration.*

class SubscriberStabilitySpec extends Fs2Suite {
val attempts = 100
Expand Down Expand Up @@ -67,4 +68,32 @@ class SubscriberStabilitySpec extends Fs2Suite {
.replicateA_(attempts)
}
}

test("StreamSubscriber cancels subscription on downstream cancellation") {
def makePublisher(
requestCalled: AtomicBoolean,
subscriptionCancelled: AtomicBoolean
): Publisher[ByteBuffer] =
new Publisher[ByteBuffer] {

class SubscriptionImpl extends Subscription {
override def request(n: Long): Unit = requestCalled.set(true)
override def cancel(): Unit = subscriptionCancelled.set(true)
}

override def subscribe(s: Subscriber[? >: ByteBuffer]): Unit =
s.onSubscribe(new SubscriptionImpl)
}

for {
requestCalled <- IO(new AtomicBoolean(false))
subscriptionCancelled <- IO(new AtomicBoolean(false))
publisher = makePublisher(requestCalled, subscriptionCancelled)
_ <- fromPublisher[IO](publisher, chunkSize = 1)
.interruptWhen(Stream.eval(IO(requestCalled.get())).repeat.spaced(10.millis))
.compile
.drain
_ <- IO(subscriptionCancelled.get).assert
} yield ()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,11 @@ object StreamSubscriber {
def onComplete(): Unit = nextState(OnComplete)
def onFinalize: F[Unit] = F.delay(nextState(OnFinalize))
def dequeue1: F[Either[Throwable, Option[Chunk[A]]]] =
F.async_[Either[Throwable, Option[Chunk[A]]]] { cb =>
nextState(OnDequeue(out => cb(Right(out))))
F.async[Either[Throwable, Option[Chunk[A]]]] { cb =>
F.delay {
nextState(OnDequeue(out => cb(Right(out))))
Some(F.unit)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the cleanup of the subscription handled around this effect? Producing Some(F.unit) basically says that there is no cleanup action which must be taken on the state resulting from nextState, and it actually corresponds to a memory leak unless we have a higher level guarantee of cleanup.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cleanup happens with the onFinalize call in the Streams.bracket here:

def stream(subscribe: F[Unit])(implicit ev: ApplicativeError[F, Throwable]): Stream[F, A] =
      Stream.bracket(subscribe)(_ => onFinalize) >> Stream
        .eval(dequeue1)
        .repeat
        .rethrow
        .unNoneTerminate
        .unchunks

which is the only place where dequeue1 is used internally.

Although technically I guess someone could call StreamSubscriber#sub.dequeue1 directly, since it is exposed here (unlike analogous flow interop).

}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ package fs2
package interop
package reactivestreams

import cats.effect._
import cats.effect.*
import cats.effect.std.Random
import org.reactivestreams._

import scala.concurrent.duration._
import org.reactivestreams.*

import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicBoolean
import scala.concurrent.duration.*

class SubscriberStabilitySpec extends Fs2Suite {
test("StreamSubscriber has no race condition") {
Expand Down Expand Up @@ -87,4 +87,32 @@ class SubscriberStabilitySpec extends Fs2Suite {
if (failed)
fail("Uncaught exception was reported")
}

test("StreamSubscriber cancels subscription on downstream cancellation") {
def makePublisher(
requestCalled: AtomicBoolean,
subscriptionCancelled: AtomicBoolean
): Publisher[ByteBuffer] =
new Publisher[ByteBuffer] {

class SubscriptionImpl extends Subscription {
override def request(n: Long): Unit = requestCalled.set(true)
override def cancel(): Unit = subscriptionCancelled.set(true)
}

override def subscribe(s: Subscriber[? >: ByteBuffer]): Unit =
s.onSubscribe(new SubscriptionImpl)
}

for {
requestCalled <- IO(new AtomicBoolean(false))
subscriptionCancelled <- IO(new AtomicBoolean(false))
publisher = makePublisher(requestCalled, subscriptionCancelled)
_ <- fromPublisher[IO, ByteBuffer](publisher, bufferSize = 1)
.interruptWhen(Stream.eval(IO(requestCalled.get())).repeat.spaced(10.millis))
.compile
.drain
_ <- IO(subscriptionCancelled.get).assert
} yield ()
}
}