diff --git a/addons/akka-stream/src/main/scala/com/spingo/op_rabbit/stream/MessagePublisherSink.scala b/addons/akka-stream/src/main/scala/com/spingo/op_rabbit/stream/MessagePublisherSink.scala index a7da6fb..4fd0d32 100644 --- a/addons/akka-stream/src/main/scala/com/spingo/op_rabbit/stream/MessagePublisherSink.scala +++ b/addons/akka-stream/src/main/scala/com/spingo/op_rabbit/stream/MessagePublisherSink.scala @@ -1,98 +1,99 @@ -package com.spingo.op_rabbit -package stream +package com.spingo.op_rabbit.stream -import akka.actor.{ActorRef,Props} +import com.spingo.op_rabbit._ +import com.spingo.op_rabbit.Message._ +import akka.stream.stage.GraphStage +import akka.actor.{ActorRef, Props} import akka.actor.FSM import akka.pattern.ask import akka.stream.scaladsl.Sink -import akka.stream.actor._ import scala.concurrent.{Future, Promise} import scala.concurrent.duration._ import com.timcharper.acked.AckedSink -import scala.util.{Try,Success,Failure} +import scala.util.{Try, Success,Failure} +import akka.stream._ +import akka.stream.stage.GraphStageLogic +import akka.stream.stage.GraphStageWithMaterializedValue +import akka.stream.stage.InHandler +import akka.stream.scaladsl.Flow +import akka.util.Timeout -private [stream] object MessagePublisherSinkActor { - sealed trait State - case object Running extends State - case object Stopping extends State - case object AllDoneFuturePlease -} -private class MessagePublisherSinkActor(rabbitControl: ActorRef, timeoutAfter: FiniteDuration, qos: Int) extends ActorSubscriber with FSM[MessagePublisherSinkActor.State, Unit] { - import ActorSubscriberMessage._ - import MessagePublisherSinkActor._ +private class MessagePublisherSink(rabbitControl: ActorRef, timeoutAfter: FiniteDuration, qos: Int) extends GraphStageWithMaterializedValue[SinkShape[(Promise[Unit],Message)], Future[Unit]] { + val in = Inlet[(Promise[Unit],Message)]("MessagePublisherSink.in") - private val queue = scala.collection.mutable.Map.empty[Long, Promise[Unit]] - private val completed = Promise[Unit] + val shape = SinkShape.of(in) - startWith(Running, ()) + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Future[Unit]) = { + val completed = Promise[Unit]() - override val requestStrategy = new MaxInFlightRequestStrategy(max = qos) { - override def inFlightInternally: Int = queue.size - } + val logic = new GraphStageLogic(shape) { + private val queue = scala.collection.mutable.Map.empty[Long, Promise[Unit]] - override def postRestart(reason: Throwable): Unit = { - stopWith(Failure(reason)) - super.postRestart(reason) - } + // callback to schedule the rabbitControl responses into the stage + private val futureCallback = getAsyncCallback[Try[Message.ConfirmResponse]]({ + case Success(Message.Ack(id)) => + queue.remove(id).get.success(()) + pullIfNeeded() - private def stopWith(reason: Try[Unit]): Unit = { - context stop self - completed.tryComplete(reason) - } + case Success(Message.Nack(id)) => + queue.remove(id).get.failure(new MessageNacked(id)) + pullIfNeeded() - when(Running) { - case Event(response: Message.ConfirmResponse, _) => - handleResponse(response) - stay + case Success(Message.Fail(id, exception: Throwable)) => + queue.remove(id).get.failure(exception) + pullIfNeeded() - case Event(OnError(e), _) => - completed.tryFailure(e) - goto(Stopping) + case Failure(exception) => + // currently fails the stream - maybe better just fail the message - needs additional context + fail(exception) + }) - case Event(OnComplete, _) => - goto(Stopping) - } + override def preStart(): Unit = { + // we must ensure we can acknowledge messages even on stream complete + setKeepGoing(true) + pull(in) + } + + setHandler(in, new InHandler { + override def onPush(): Unit = { + val (promise, msg) = grab(in) + queue(msg.id) = promise - when(Stopping) { - case Event(response: Message.ConfirmResponse, _) => - handleResponse(response) - if(queue.isEmpty) - stop - else - stay - } + val eventualResult = rabbitControl.ask(msg)(Timeout(timeoutAfter)).mapTo[ConfirmResponse] - whenUnhandled { - case Event(OnNext((p: Promise[Unit] @unchecked, msg: Message)), _) => - queue(msg.id) = p - rabbitControl ! msg - stay + // TODO: which EC to schedule the callback onto? + eventualResult.onComplete(futureCallback.invoke)(materializer.executionContext) - case Event(MessagePublisherSinkActor.AllDoneFuturePlease,_) => - sender ! completed.future - stay - } + pullIfNeeded() + } - onTransition { - case Running -> Stopping if queue.isEmpty => - stopWith(Success(())) - } + override def onUpstreamFinish(): Unit = { + if (queue.isEmpty) complete() + } - onTermination { - case e: StopEvent => - stopWith(Success(())) - } + override def onUpstreamFailure(ex: Throwable): Unit = { + fail(ex) + } + }) + + private def pullIfNeeded(): Unit = { + if (isClosed(in) && queue.isEmpty) complete() + else if (queue.size < qos && !hasBeenPulled(in)) tryPull(in) + } - private val handleResponse: Message.ConfirmResponse => Unit = { - case Message.Ack(id) => - queue.remove(id).get.success(()) + private def complete(): Unit = { + completed.success(()) + completeStage() + } - case Message.Nack(id) => - queue.remove(id).get.failure(new MessageNacked(id)) + private def fail(ex: Throwable): Unit = { + completed.failure(ex) + failStage(ex) + } + } - case Message.Fail(id, exception: Throwable) => - queue.remove(id).get.failure(exception) + (logic, completed.future) } } @@ -130,12 +131,6 @@ object MessagePublisherSink { @param timeoutAfter The duration for which we'll wait for a message to be acked; note, timeouts and non-acknowledged messages will cause the upstream elements to fail. The sink will not throw an exception. */ def apply(rabbitControl: ActorRef, timeoutAfter: FiniteDuration = 30 seconds, qos: Int = 8): AckedSink[Message, Future[Unit]] = AckedSink { - Sink.actorSubscriber[(Promise[Unit], Message)](Props(new MessagePublisherSinkActor(rabbitControl, timeoutAfter, qos))). - mapMaterializedValue { subscriber => - implicit val akkaTimeout = akka.util.Timeout(timeoutAfter) - implicit val ec = SameThreadExecutionContext - - (subscriber ? MessagePublisherSinkActor.AllDoneFuturePlease).mapTo[Future[Unit]].flatMap(identity) - } + new MessagePublisherSink(rabbitControl, timeoutAfter, qos) } } diff --git a/build.sbt b/build.sbt index 0c1cb08..f932ed4 100644 --- a/build.sbt +++ b/build.sbt @@ -2,8 +2,8 @@ import java.util.Properties val json4sVersion = "3.6.6" val circeVersion = "0.12.3" -val akkaVersion = "2.5.25" -val playVersion = "2.7.4" +val akkaVersion = "2.6.6" +val playVersion = "2.9.0" val appProperties = { val prop = new Properties()