Skip to content

Commit

Permalink
Set linger to a finite value before closing ZeroMQ channels
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarchambault committed May 28, 2024
1 parent d699cea commit c011552
Show file tree
Hide file tree
Showing 14 changed files with 114 additions and 40 deletions.
9 changes: 8 additions & 1 deletion modules/echo/src/main/scala/almond/echo/EchoKernel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,14 @@ object EchoKernel extends CaseApp[Options] {

log.debug("Running kernel")
Kernel.create(new EchoInterpreter, interpreterEc, kernelThreads, logCtx)
.flatMap(_.runOnConnectionFile(connectionFile, "echo", zeromqThreads, Nil, autoClose = true))
.flatMap(_.runOnConnectionFile(
connectionFile,
"echo",
zeromqThreads,
Nil,
autoClose = true,
timeout = options.lingerDuration
))
.unsafeRunSync()(IORuntime.global)
}
}
18 changes: 16 additions & 2 deletions modules/echo/src/main/scala/almond/echo/Options.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,29 @@ import almond.kernel.install.{Options => InstallOptions}
import caseapp.{HelpMessage, Recurse}
import caseapp.core.help.Help
import caseapp.core.parser.Parser
import caseapp.Hidden
import scala.concurrent.duration.{Duration, DurationInt}

final case class Options(
connectionFile: Option[String] = None,
@HelpMessage("Log level (one of none, error, warn, info, or debug)")
log: String = "warn",
install: Boolean = false,
@Recurse
installOptions: InstallOptions = InstallOptions()
)
installOptions: InstallOptions = InstallOptions(),
@HelpMessage(
"""Time given to the client to accept ZeroMQ messages before exiting. Parsed with scala.concurrent.duration.Duration, this accepts things like "Inf" or "5 seconds""""
)
@Hidden
linger: Option[String] = None
) {

lazy val lingerDuration = linger
.map(_.trim)
.filter(_.nonEmpty)
.map(Duration(_))
.getOrElse(5.seconds)
}

object Options {
implicit lazy val parser: Parser[Options] = Parser.derive
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,11 @@ class KernelLauncher(
}

def close(): Unit = {
conn.close(partial = false).unsafeRunTimed(2.minutes)(IORuntime.global).getOrElse {
sys.error("Timeout when closing ZeroMQ connections")
}
conn.close(partial = false, timeout = 30.seconds)
.unsafeRunTimed(2.minutes)(IORuntime.global)
.getOrElse {
sys.error("Timeout when closing ZeroMQ connections")
}

if (perTestZeroMqContext) {
val t = stackTracePrinterThread(output)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import java.nio.channels.ClosedSelectorException
import scala.concurrent.duration.DurationInt
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
import scala.concurrent.duration.Duration

object Launcher extends CaseApp[LauncherOptions] {

Expand Down Expand Up @@ -325,7 +326,8 @@ object Launcher extends CaseApp[LauncherOptions] {
"scala",
zeromqThreads,
Nil,
autoClose = false
autoClose = false,
timeout = Duration.Inf // unused here
))
.unsafeRunSync()(IORuntime.global)
val leftoverMessages: Seq[(Channel, RawMessage)] = run.unsafeRunSync()(IORuntime.global)
Expand Down Expand Up @@ -410,7 +412,9 @@ object Launcher extends CaseApp[LauncherOptions] {
for (outputHandler <- outputHandlerOpt)
outputHandler.done()

try conn.close(partial = false).unsafeRunSync()(IORuntime.global)
try
conn.close(partial = false, timeout = options.lingerDuration)
.unsafeRunSync()(IORuntime.global)
catch {
case NonFatal(e) =>
throw new Exception(e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import caseapp._

import scala.cli.directivehandler.EitherSequence._
import scala.collection.mutable
import scala.concurrent.duration.{Duration, DurationInt}

// format: off
final case class LauncherOptions(
Expand Down Expand Up @@ -34,7 +35,10 @@ final case class LauncherOptions(
quiet: Option[Boolean] = None,
silentImports: Option[Boolean] = None,
useNotebookCoursierLogger: Option[Boolean] = None,
customDirectiveGroup: List[String] = Nil
customDirectiveGroup: List[String] = Nil,
@HelpMessage("Time given to the client to accept ZeroMQ messages before handing over the connections to the kernel. Parsed with scala.concurrent.duration.Duration, this accepts things like \"Inf\" or \"5 seconds\"")
@Hidden
linger: Option[String] = None
) {
// format: on

Expand Down Expand Up @@ -91,6 +95,12 @@ final case class LauncherOptions(
groups
}
}

lazy val lingerDuration = linger
.map(_.trim)
.filter(_.nonEmpty)
.map(Duration(_))
.getOrElse(5.seconds)
}

object LauncherOptions {
Expand Down
12 changes: 11 additions & 1 deletion modules/scala/scala-kernel/src/main/scala/almond/Options.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import coursierapi.{Dependency, Module}
import coursier.parse.{DependencyParser, ModuleParser}

import scala.collection.compat._
import scala.concurrent.duration.{Duration, DurationInt}
import scala.jdk.CollectionConverters._

// format: off
Expand Down Expand Up @@ -129,7 +130,11 @@ final case class Options(

@HelpMessage("Pass launcher directive groups with this option. These directives will be either ignored (see --ignore-launcher-directives-in), or trigger an unused directive warning")
@Hidden
launcherDirectiveGroup: List[String] = Nil
launcherDirectiveGroup: List[String] = Nil,

@HelpMessage("""Time given to the client to accept ZeroMQ messages before exiting. Parsed with scala.concurrent.duration.Duration, this accepts things like "Inf" or "5 seconds"""")
@Hidden
linger: Option[String] = None
) {
// format: on

Expand Down Expand Up @@ -300,6 +305,11 @@ final case class Options(
readFromArray(bytes)(KernelOptions.AsJson.codec)
}

lazy val lingerDuration = linger
.map(_.trim)
.filter(_.nonEmpty)
.map(Duration(_))
.getOrElse(5.seconds)
}

object Options {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,8 @@ object ScalaKernel extends CaseApp[Options] {
"scala",
zeromqThreads,
options.leftoverMessages0(),
autoClose = true
autoClose = true,
timeout = options.lingerDuration
))
.unsafeRunSync()(IORuntime.global)
finally
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ abstract class Connection {
*
* Can be run multiple times. Only the first call will actually close the channels.
*/
def close(partial: Boolean): IO[Unit]
def close(partial: Boolean, timeout: Duration): IO[Unit]

/** Try to read a message from the specified [[Channel]].
*
Expand Down Expand Up @@ -77,7 +77,7 @@ abstract class Connection {
final def sink: Pipe[IO, (Channel, Message), Unit] =
_.evalMap((send _).tupled)

final def autoCloseSink(partial: Boolean): Pipe[IO, (Channel, Message), Unit] =
s => Stream.bracket(IO.unit)(_ => close(partial)).flatMap(_ => sink(s))
final def autoCloseSink(partial: Boolean, timeout: Duration): Pipe[IO, (Channel, Message), Unit] =
s => Stream.bracket(IO.unit)(_ => close(partial, timeout)).flatMap(_ => sink(s))

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import org.zeromq.ZMQ.{PollItem, Poller}
import zmq.ZError

import scala.concurrent.duration.Duration
import cats.Parallel

final class ZeromqConnection(
params: ConnectionParameters,
Expand Down Expand Up @@ -203,17 +204,17 @@ final class ZeromqConnection(
.getOrElse(IO.pure(None))
}.evalOn(threads.pollingEc).flatMap(identity)

def close(partial: Boolean): IO[Unit] = {
def close(partial: Boolean, timeout: Duration): IO[Unit] = {

val log0 = IO(log.debug(s"Closing channels for $params"))

val channels = Seq(
val channels = List(
requests0,
control0,
stdin0
) ++ (if (partial) Nil else Seq(publish0))
) ::: (if (partial) Nil else List(publish0))

val t = channels.foldLeft(IO.unit)((acc, c) => acc *> c.close)
val t = Parallel.parTraverse(channels)(_.close(timeout))

val other = IO {
log.debug(s"Closing things for $params" + (if (partial) " (partial)" else ""))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ trait ZeromqSocket {
def open: IO[Unit]
def read: IO[Option[Message]]
def send(message: Message): IO[Unit]
def close: IO[Unit]
def close(timeout: Duration): IO[Unit]

def channel: ZMQ.Socket
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,16 @@ final class ZeromqSocketImpl(
}.evalOn(ec)
)

val close: IO[Unit] = {
def close(timeout: Duration): IO[Unit] = {

val t = IO {
if (!closed) {
val linger = timeout match {
case d: FiniteDuration => d.toMillis.toInt
case _ => -1
}
if (channel.getLinger != linger)
channel.setLinger(linger)
channel.close()
closed = true
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ object ZeromqConnectionTests extends TestSuite {
_ = assert(resp._1 == Channel.Requests)
_ = assert(resp._2.copy(idents = Nil) == msg0)
// TODO Enforce this is run via bracketing
_ <- kernel.close(partial = false)
_ <- server.close(partial = false)
_ <- kernel.close(partial = false, timeout = 2.seconds)
_ <- server.close(partial = false, timeout = 2.seconds)
} yield ()

t.unsafeRunSync()(IORuntime.global)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import org.zeromq.{SocketType, ZMQ}
import utest._

import scala.concurrent.ExecutionContext
import scala.concurrent.duration.DurationInt
import java.nio.charset.StandardCharsets

object ZeromqSocketTests extends TestSuite {
Expand Down Expand Up @@ -86,8 +87,8 @@ object ZeromqSocketTests extends TestSuite {
readOpt <- rep.read
_ = assert(readOpt.contains(msg))
// FIXME Closing should be enforced via bracketing
_ <- req.close
_ <- rep.close
_ <- req.close(timeout = 5.seconds)
_ <- rep.close(timeout = 5.seconds)
} yield ()

t.unsafeRunSync()(IORuntime.global)
Expand Down Expand Up @@ -151,8 +152,8 @@ object ZeromqSocketTests extends TestSuite {
readOpt <- rep.read
_ = assert(readOpt.contains(msg))
// FIXME Closing should be enforced via bracketing
_ <- req.close
_ <- rep.close
_ <- req.close(timeout = 5.seconds)
_ <- rep.close(timeout = 5.seconds)
} yield ()

t.unsafeRunSync()(IORuntime.global)
Expand Down
Loading

0 comments on commit c011552

Please sign in to comment.