diff --git a/modules/echo/src/main/scala/almond/echo/EchoKernel.scala b/modules/echo/src/main/scala/almond/echo/EchoKernel.scala index 035af382c..fdb6e4c4a 100644 --- a/modules/echo/src/main/scala/almond/echo/EchoKernel.scala +++ b/modules/echo/src/main/scala/almond/echo/EchoKernel.scala @@ -53,7 +53,14 @@ object EchoKernel extends CaseApp[Options] { log.debug("Running kernel") Kernel.create(new EchoInterpreter, interpreterEc, kernelThreads, logCtx) - .flatMap(_.runOnConnectionFile(connectionFile, "echo", zeromqThreads, Nil, autoClose = true)) + .flatMap(_.runOnConnectionFile( + connectionFile, + "echo", + zeromqThreads, + Nil, + autoClose = true, + timeout = options.lingerDuration + )) .unsafeRunSync()(IORuntime.global) } } diff --git a/modules/echo/src/main/scala/almond/echo/Options.scala b/modules/echo/src/main/scala/almond/echo/Options.scala index 11724fbdd..4ec4c7bf0 100644 --- a/modules/echo/src/main/scala/almond/echo/Options.scala +++ b/modules/echo/src/main/scala/almond/echo/Options.scala @@ -4,6 +4,8 @@ import almond.kernel.install.{Options => InstallOptions} import caseapp.{HelpMessage, Recurse} import caseapp.core.help.Help import caseapp.core.parser.Parser +import caseapp.Hidden +import scala.concurrent.duration.{Duration, DurationInt} final case class Options( connectionFile: Option[String] = None, @@ -11,8 +13,20 @@ final case class Options( log: String = "warn", install: Boolean = false, @Recurse - installOptions: InstallOptions = InstallOptions() -) + installOptions: InstallOptions = InstallOptions(), + @HelpMessage( + """Time given to the client to accept ZeroMQ messages before exiting. Parsed with scala.concurrent.duration.Duration, this accepts things like "Inf" or "5 seconds"""" + ) + @Hidden + linger: Option[String] = None +) { + + lazy val lingerDuration = linger + .map(_.trim) + .filter(_.nonEmpty) + .map(Duration(_)) + .getOrElse(5.seconds) +} object Options { implicit lazy val parser: Parser[Options] = Parser.derive diff --git a/modules/scala/integration/src/main/scala/almond/integration/KernelLauncher.scala b/modules/scala/integration/src/main/scala/almond/integration/KernelLauncher.scala index 586c5e2f6..db395d9e7 100644 --- a/modules/scala/integration/src/main/scala/almond/integration/KernelLauncher.scala +++ b/modules/scala/integration/src/main/scala/almond/integration/KernelLauncher.scala @@ -275,9 +275,11 @@ class KernelLauncher( } def close(): Unit = { - conn.close(partial = false).unsafeRunTimed(2.minutes)(IORuntime.global).getOrElse { - sys.error("Timeout when closing ZeroMQ connections") - } + conn.close(partial = false, timeout = 30.seconds) + .unsafeRunTimed(2.minutes)(IORuntime.global) + .getOrElse { + sys.error("Timeout when closing ZeroMQ connections") + } if (perTestZeroMqContext) { val t = stackTracePrinterThread(output) diff --git a/modules/scala/launcher/src/main/scala/almond/launcher/Launcher.scala b/modules/scala/launcher/src/main/scala/almond/launcher/Launcher.scala index 2bb9deb9d..a51b62432 100644 --- a/modules/scala/launcher/src/main/scala/almond/launcher/Launcher.scala +++ b/modules/scala/launcher/src/main/scala/almond/launcher/Launcher.scala @@ -26,6 +26,7 @@ import java.nio.channels.ClosedSelectorException import scala.concurrent.duration.DurationInt import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal +import scala.concurrent.duration.Duration object Launcher extends CaseApp[LauncherOptions] { @@ -325,7 +326,8 @@ object Launcher extends CaseApp[LauncherOptions] { "scala", zeromqThreads, Nil, - autoClose = false + autoClose = false, + timeout = Duration.Inf // unused here )) .unsafeRunSync()(IORuntime.global) val leftoverMessages: Seq[(Channel, RawMessage)] = run.unsafeRunSync()(IORuntime.global) @@ -410,7 +412,9 @@ object Launcher extends CaseApp[LauncherOptions] { for (outputHandler <- outputHandlerOpt) outputHandler.done() - try conn.close(partial = false).unsafeRunSync()(IORuntime.global) + try + conn.close(partial = false, timeout = options.lingerDuration) + .unsafeRunSync()(IORuntime.global) catch { case NonFatal(e) => throw new Exception(e) diff --git a/modules/scala/launcher/src/main/scala/almond/launcher/LauncherOptions.scala b/modules/scala/launcher/src/main/scala/almond/launcher/LauncherOptions.scala index 460531a9f..bcf0d4bdd 100644 --- a/modules/scala/launcher/src/main/scala/almond/launcher/LauncherOptions.scala +++ b/modules/scala/launcher/src/main/scala/almond/launcher/LauncherOptions.scala @@ -6,6 +6,7 @@ import caseapp._ import scala.cli.directivehandler.EitherSequence._ import scala.collection.mutable +import scala.concurrent.duration.{Duration, DurationInt} // format: off final case class LauncherOptions( @@ -34,7 +35,10 @@ final case class LauncherOptions( quiet: Option[Boolean] = None, silentImports: Option[Boolean] = None, useNotebookCoursierLogger: Option[Boolean] = None, - customDirectiveGroup: List[String] = Nil + customDirectiveGroup: List[String] = Nil, + @HelpMessage("Time given to the client to accept ZeroMQ messages before handing over the connections to the kernel. Parsed with scala.concurrent.duration.Duration, this accepts things like \"Inf\" or \"5 seconds\"") + @Hidden + linger: Option[String] = None ) { // format: on @@ -91,6 +95,12 @@ final case class LauncherOptions( groups } } + + lazy val lingerDuration = linger + .map(_.trim) + .filter(_.nonEmpty) + .map(Duration(_)) + .getOrElse(5.seconds) } object LauncherOptions { diff --git a/modules/scala/scala-kernel/src/main/scala/almond/Options.scala b/modules/scala/scala-kernel/src/main/scala/almond/Options.scala index 014f0f8ef..165fb4dd8 100644 --- a/modules/scala/scala-kernel/src/main/scala/almond/Options.scala +++ b/modules/scala/scala-kernel/src/main/scala/almond/Options.scala @@ -17,6 +17,7 @@ import coursierapi.{Dependency, Module} import coursier.parse.{DependencyParser, ModuleParser} import scala.collection.compat._ +import scala.concurrent.duration.{Duration, DurationInt} import scala.jdk.CollectionConverters._ // format: off @@ -129,7 +130,11 @@ final case class Options( @HelpMessage("Pass launcher directive groups with this option. These directives will be either ignored (see --ignore-launcher-directives-in), or trigger an unused directive warning") @Hidden - launcherDirectiveGroup: List[String] = Nil + launcherDirectiveGroup: List[String] = Nil, + + @HelpMessage("""Time given to the client to accept ZeroMQ messages before exiting. Parsed with scala.concurrent.duration.Duration, this accepts things like "Inf" or "5 seconds"""") + @Hidden + linger: Option[String] = None ) { // format: on @@ -300,6 +305,11 @@ final case class Options( readFromArray(bytes)(KernelOptions.AsJson.codec) } + lazy val lingerDuration = linger + .map(_.trim) + .filter(_.nonEmpty) + .map(Duration(_)) + .getOrElse(5.seconds) } object Options { diff --git a/modules/scala/scala-kernel/src/main/scala/almond/ScalaKernel.scala b/modules/scala/scala-kernel/src/main/scala/almond/ScalaKernel.scala index 1ee01aac3..fafff3ee1 100644 --- a/modules/scala/scala-kernel/src/main/scala/almond/ScalaKernel.scala +++ b/modules/scala/scala-kernel/src/main/scala/almond/ScalaKernel.scala @@ -252,7 +252,8 @@ object ScalaKernel extends CaseApp[Options] { "scala", zeromqThreads, options.leftoverMessages0(), - autoClose = true + autoClose = true, + timeout = options.lingerDuration )) .unsafeRunSync()(IORuntime.global) finally diff --git a/modules/shared/channels/src/main/scala/almond/channels/Connection.scala b/modules/shared/channels/src/main/scala/almond/channels/Connection.scala index 9abbfe44d..4b2d1f715 100644 --- a/modules/shared/channels/src/main/scala/almond/channels/Connection.scala +++ b/modules/shared/channels/src/main/scala/almond/channels/Connection.scala @@ -38,7 +38,7 @@ abstract class Connection { * * Can be run multiple times. Only the first call will actually close the channels. */ - def close(partial: Boolean): IO[Unit] + def close(partial: Boolean, timeout: Duration): IO[Unit] /** Try to read a message from the specified [[Channel]]. * @@ -77,7 +77,7 @@ abstract class Connection { final def sink: Pipe[IO, (Channel, Message), Unit] = _.evalMap((send _).tupled) - final def autoCloseSink(partial: Boolean): Pipe[IO, (Channel, Message), Unit] = - s => Stream.bracket(IO.unit)(_ => close(partial)).flatMap(_ => sink(s)) + final def autoCloseSink(partial: Boolean, timeout: Duration): Pipe[IO, (Channel, Message), Unit] = + s => Stream.bracket(IO.unit)(_ => close(partial, timeout)).flatMap(_ => sink(s)) } diff --git a/modules/shared/channels/src/main/scala/almond/channels/zeromq/ZeromqConnection.scala b/modules/shared/channels/src/main/scala/almond/channels/zeromq/ZeromqConnection.scala index afc0fd8b6..ea229e9e1 100644 --- a/modules/shared/channels/src/main/scala/almond/channels/zeromq/ZeromqConnection.scala +++ b/modules/shared/channels/src/main/scala/almond/channels/zeromq/ZeromqConnection.scala @@ -12,6 +12,7 @@ import org.zeromq.ZMQ.{PollItem, Poller} import zmq.ZError import scala.concurrent.duration.Duration +import cats.Parallel final class ZeromqConnection( params: ConnectionParameters, @@ -203,17 +204,17 @@ final class ZeromqConnection( .getOrElse(IO.pure(None)) }.evalOn(threads.pollingEc).flatMap(identity) - def close(partial: Boolean): IO[Unit] = { + def close(partial: Boolean, timeout: Duration): IO[Unit] = { val log0 = IO(log.debug(s"Closing channels for $params")) - val channels = Seq( + val channels = List( requests0, control0, stdin0 - ) ++ (if (partial) Nil else Seq(publish0)) + ) ::: (if (partial) Nil else List(publish0)) - val t = channels.foldLeft(IO.unit)((acc, c) => acc *> c.close) + val t = Parallel.parTraverse(channels)(_.close(timeout)) val other = IO { log.debug(s"Closing things for $params" + (if (partial) " (partial)" else "")) diff --git a/modules/shared/channels/src/main/scala/almond/channels/zeromq/ZeromqSocket.scala b/modules/shared/channels/src/main/scala/almond/channels/zeromq/ZeromqSocket.scala index e4abcfb9c..0ced53b88 100644 --- a/modules/shared/channels/src/main/scala/almond/channels/zeromq/ZeromqSocket.scala +++ b/modules/shared/channels/src/main/scala/almond/channels/zeromq/ZeromqSocket.scala @@ -13,7 +13,7 @@ trait ZeromqSocket { def open: IO[Unit] def read: IO[Option[Message]] def send(message: Message): IO[Unit] - def close: IO[Unit] + def close(timeout: Duration): IO[Unit] def channel: ZMQ.Socket } diff --git a/modules/shared/channels/src/main/scala/almond/channels/zeromq/ZeromqSocketImpl.scala b/modules/shared/channels/src/main/scala/almond/channels/zeromq/ZeromqSocketImpl.scala index 8139ad08e..0795f3359 100644 --- a/modules/shared/channels/src/main/scala/almond/channels/zeromq/ZeromqSocketImpl.scala +++ b/modules/shared/channels/src/main/scala/almond/channels/zeromq/ZeromqSocketImpl.scala @@ -218,10 +218,16 @@ final class ZeromqSocketImpl( }.evalOn(ec) ) - val close: IO[Unit] = { + def close(timeout: Duration): IO[Unit] = { val t = IO { if (!closed) { + val linger = timeout match { + case d: FiniteDuration => d.toMillis.toInt + case _ => -1 + } + if (channel.getLinger != linger) + channel.setLinger(linger) channel.close() closed = true } diff --git a/modules/shared/channels/src/test/scala/almond/channels/zeromq/ZeromqConnectionTests.scala b/modules/shared/channels/src/test/scala/almond/channels/zeromq/ZeromqConnectionTests.scala index d6ffaff35..967a2a70b 100644 --- a/modules/shared/channels/src/test/scala/almond/channels/zeromq/ZeromqConnectionTests.scala +++ b/modules/shared/channels/src/test/scala/almond/channels/zeromq/ZeromqConnectionTests.scala @@ -43,8 +43,8 @@ object ZeromqConnectionTests extends TestSuite { _ = assert(resp._1 == Channel.Requests) _ = assert(resp._2.copy(idents = Nil) == msg0) // TODO Enforce this is run via bracketing - _ <- kernel.close(partial = false) - _ <- server.close(partial = false) + _ <- kernel.close(partial = false, timeout = 2.seconds) + _ <- server.close(partial = false, timeout = 2.seconds) } yield () t.unsafeRunSync()(IORuntime.global) diff --git a/modules/shared/channels/src/test/scala/almond/channels/zeromq/ZeromqSocketTests.scala b/modules/shared/channels/src/test/scala/almond/channels/zeromq/ZeromqSocketTests.scala index ce92c989b..990e89a75 100644 --- a/modules/shared/channels/src/test/scala/almond/channels/zeromq/ZeromqSocketTests.scala +++ b/modules/shared/channels/src/test/scala/almond/channels/zeromq/ZeromqSocketTests.scala @@ -10,6 +10,7 @@ import org.zeromq.{SocketType, ZMQ} import utest._ import scala.concurrent.ExecutionContext +import scala.concurrent.duration.DurationInt import java.nio.charset.StandardCharsets object ZeromqSocketTests extends TestSuite { @@ -86,8 +87,8 @@ object ZeromqSocketTests extends TestSuite { readOpt <- rep.read _ = assert(readOpt.contains(msg)) // FIXME Closing should be enforced via bracketing - _ <- req.close - _ <- rep.close + _ <- req.close(timeout = 5.seconds) + _ <- rep.close(timeout = 5.seconds) } yield () t.unsafeRunSync()(IORuntime.global) @@ -151,8 +152,8 @@ object ZeromqSocketTests extends TestSuite { readOpt <- rep.read _ = assert(readOpt.contains(msg)) // FIXME Closing should be enforced via bracketing - _ <- req.close - _ <- rep.close + _ <- req.close(timeout = 5.seconds) + _ <- rep.close(timeout = 5.seconds) } yield () t.unsafeRunSync()(IORuntime.global) diff --git a/modules/shared/kernel/src/main/scala/almond/kernel/Kernel.scala b/modules/shared/kernel/src/main/scala/almond/kernel/Kernel.scala index a18e898df..58e97f3c7 100644 --- a/modules/shared/kernel/src/main/scala/almond/kernel/Kernel.scala +++ b/modules/shared/kernel/src/main/scala/almond/kernel/Kernel.scala @@ -24,6 +24,7 @@ import fs2.{Pipe, Stream} import scala.concurrent.ExecutionContext import scala.concurrent.duration.DurationInt +import scala.concurrent.duration.Duration final case class Kernel( interpreter: IOInterpreter, @@ -186,7 +187,8 @@ final case class Kernel( connection: ConnectionParameters, kernelId: String, zeromqThreads: ZeromqThreads, - leftoverMessages: Seq[(Channel, RawMessage)] + leftoverMessages: Seq[(Channel, RawMessage)], + timeout: Duration ): IO[Unit] = for { t <- runOnConnectionAllowClose0( @@ -194,7 +196,8 @@ final case class Kernel( kernelId, zeromqThreads, leftoverMessages, - autoClose = true + autoClose = true, + timeout = timeout ) (run, _) = t _ <- run @@ -205,7 +208,8 @@ final case class Kernel( kernelId: String, zeromqThreads: ZeromqThreads, leftoverMessages: Seq[(Channel, RawMessage)], - autoClose: Boolean + autoClose: Boolean, + timeout: Duration ): IO[(IO[Unit], Connection)] = for { c <- connection.channels( @@ -219,7 +223,11 @@ final case class Kernel( val run0 = for { _ <- c.open - _ <- run(c.stream(), c.autoCloseSink(partial = !autoClose), leftoverMessages) + _ <- run( + c.stream(), + c.autoCloseSink(partial = !autoClose, timeout = timeout), + leftoverMessages + ) } yield () (run0, c) } @@ -236,14 +244,16 @@ final case class Kernel( kernelId: String, zeromqThreads: ZeromqThreads, leftoverMessages: Seq[(Channel, RawMessage)], - autoClose: Boolean + autoClose: Boolean, + timeout: Duration ): IO[(IO[Seq[(Channel, RawMessage)]], Connection)] = runOnConnectionAllowClose0( connection, kernelId, zeromqThreads, leftoverMessages, - autoClose + autoClose, + timeout ).map { case (run, conn) => val run0 = run.attempt.flatMap { @@ -264,7 +274,8 @@ final case class Kernel( kernelId: String, zeromqThreads: ZeromqThreads, leftoverMessages: Seq[(Channel, RawMessage)], - autoClose: Boolean + autoClose: Boolean, + timeout: Duration ): IO[(IO[Seq[(Channel, RawMessage)]], Connection)] = for { _ <- { @@ -285,7 +296,8 @@ final case class Kernel( kernelId, zeromqThreads, leftoverMessages, - autoClose + autoClose, + timeout ) } yield value @@ -294,7 +306,8 @@ final case class Kernel( kernelId: String, zeromqThreads: ZeromqThreads, leftoverMessages: Seq[(Channel, RawMessage)], - autoClose: Boolean + autoClose: Boolean, + timeout: Duration ): IO[Unit] = for { t <- runOnConnectionFileAllowClose( @@ -302,7 +315,8 @@ final case class Kernel( kernelId, zeromqThreads, leftoverMessages, - autoClose + autoClose, + timeout ) (run, _) = t _ <- run @@ -313,7 +327,8 @@ final case class Kernel( kernelId: String, zeromqThreads: ZeromqThreads, leftoverMessages: Seq[(Channel, RawMessage)], - autoClose: Boolean + autoClose: Boolean, + timeout: Duration ): IO[Unit] = for { t <- runOnConnectionFileAllowClose( @@ -321,7 +336,8 @@ final case class Kernel( kernelId, zeromqThreads, leftoverMessages, - autoClose + autoClose, + timeout ) (run, _) = t _ <- run @@ -332,14 +348,16 @@ final case class Kernel( kernelId: String, zeromqThreads: ZeromqThreads, leftoverMessages: Seq[(Channel, RawMessage)], - autoClose: Boolean + autoClose: Boolean, + timeout: Duration ): IO[(IO[Seq[(Channel, RawMessage)]], Connection)] = runOnConnectionFileAllowClose( Paths.get(connectionPath), kernelId, zeromqThreads, leftoverMessages, - autoClose + autoClose, + timeout ) }