From 0d7ad11b3810ce9bf8e019cf232fbf9798182641 Mon Sep 17 00:00:00 2001 From: He-Pin Date: Sun, 19 Jan 2025 12:15:34 +0800 Subject: [PATCH] feat: Add support for switching scheduler --- .../ForkJoinPoolVirtualThreadSpec.scala | 69 ++++++++++++++++ actor/src/main/resources/reference.conf | 12 +++ .../pekko/dispatch/AbstractDispatcher.scala | 2 +- .../ForkJoinExecutorConfigurator.scala | 78 ++++++++++++++++--- .../pekko/dispatch/ThreadPoolBuilder.scala | 12 +++ .../pekko/dispatch/VirtualThreadSupport.scala | 68 ++++++++++++---- .../dispatch/VirtualizedExecutorService.scala | 7 -- docs/src/main/paradox/dispatchers.md | 10 +++ docs/src/main/paradox/typed/dispatchers.md | 11 +++ .../docs/dispatcher/DispatcherDocSpec.scala | 4 + project/JdkOptions.scala | 3 + 11 files changed, 241 insertions(+), 35 deletions(-) create mode 100644 actor-tests/src/test/scala-jdk21-only/org/apache/pekko/dispatch/ForkJoinPoolVirtualThreadSpec.scala diff --git a/actor-tests/src/test/scala-jdk21-only/org/apache/pekko/dispatch/ForkJoinPoolVirtualThreadSpec.scala b/actor-tests/src/test/scala-jdk21-only/org/apache/pekko/dispatch/ForkJoinPoolVirtualThreadSpec.scala new file mode 100644 index 00000000000..455a4e2f1f9 --- /dev/null +++ b/actor-tests/src/test/scala-jdk21-only/org/apache/pekko/dispatch/ForkJoinPoolVirtualThreadSpec.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.pekko.dispatch + +import com.typesafe.config.ConfigFactory + +import org.apache.pekko +import pekko.actor.{ Actor, Props } +import pekko.testkit.{ ImplicitSender, PekkoSpec } +import pekko.util.JavaVersion + +object ForkJoinPoolVirtualThreadSpec { + val config = ConfigFactory.parseString(""" + |custom { + | task-dispatcher { + | mailbox-type = "org.apache.pekko.dispatch.SingleConsumerOnlyUnboundedMailbox" + | throughput = 5 + | fork-join-executor { + | parallelism-factor = 2 + | parallelism-max = 2 + | parallelism-min = 2 + | virtualize = on + | } + | } + |} + """.stripMargin) + + class ThreadNameActor extends Actor { + + override def receive = { + case "ping" => + sender() ! Thread.currentThread().getName + } + } + +} + +class ForkJoinPoolVirtualThreadSpec extends PekkoSpec(ForkJoinPoolVirtualThreadSpec.config) with ImplicitSender { + import ForkJoinPoolVirtualThreadSpec._ + + "PekkoForkJoinPool" must { + + "support virtualization with Virtual Thread" in { + val actor = system.actorOf(Props(new ThreadNameActor).withDispatcher("custom.task-dispatcher")) + for (_ <- 1 to 1000) { + actor ! "ping" + expectMsgPF() { case name: String => + name should include("ForkJoinPoolVirtualThreadSpec-custom.task-dispatcher-virtual-thread-") + } + } + } + + } +} diff --git a/actor/src/main/resources/reference.conf b/actor/src/main/resources/reference.conf index 6831cf056db..2cab15a2945 100644 --- a/actor/src/main/resources/reference.conf +++ b/actor/src/main/resources/reference.conf @@ -487,6 +487,18 @@ pekko { # This config is new in Pekko v1.1.0 and only has an effect if you are running with JDK 9 and above. # Read the documentation on `java.util.concurrent.ForkJoinPool` to find out more. Default in hex is 0x7fff. maximum-pool-size = 32767 + + # This config is new in Pekko v1.2.0 and only has an effect if you are running with JDK 21 and above, + # When set to `on` but underlying runtime does not support virtual threads, an Exception will throw. + # Virtualize this dispatcher as a virtual-thread-executor + # Valid values are: `on`, `off` + # + # Requirements: + # 1. JDK 21+ + # 2. add options to the JVM: + # --add-opens=java.base/jdk.internal.misc=ALL-UNNAMED + # --add-opens=java.base/java.lang=ALL-UNNAMED + virtualize = off } # This will be used if you have set "executor = "thread-pool-executor"" diff --git a/actor/src/main/scala/org/apache/pekko/dispatch/AbstractDispatcher.scala b/actor/src/main/scala/org/apache/pekko/dispatch/AbstractDispatcher.scala index d73cff521a7..147bedc874a 100644 --- a/actor/src/main/scala/org/apache/pekko/dispatch/AbstractDispatcher.scala +++ b/actor/src/main/scala/org/apache/pekko/dispatch/AbstractDispatcher.scala @@ -453,7 +453,7 @@ final class VirtualThreadExecutorConfigurator(config: Config, prerequisites: Dis } } new VirtualizedExecutorService( - tf, + tf, // the virtual thread factory pool, // the default scheduler of virtual thread loadMetricsProvider, cascadeShutdown = false // we don't want to cascade shutdown the default virtual thread scheduler diff --git a/actor/src/main/scala/org/apache/pekko/dispatch/ForkJoinExecutorConfigurator.scala b/actor/src/main/scala/org/apache/pekko/dispatch/ForkJoinExecutorConfigurator.scala index 661dfb8dd70..69c714d007d 100644 --- a/actor/src/main/scala/org/apache/pekko/dispatch/ForkJoinExecutorConfigurator.scala +++ b/actor/src/main/scala/org/apache/pekko/dispatch/ForkJoinExecutorConfigurator.scala @@ -14,13 +14,14 @@ package org.apache.pekko.dispatch import com.typesafe.config.Config +import org.apache.pekko +import pekko.dispatch.VirtualThreadSupport.newVirtualThreadFactory +import pekko.util.JavaVersion import java.lang.invoke.{ MethodHandle, MethodHandles, MethodType } -import java.util.concurrent.{ ExecutorService, ForkJoinPool, ForkJoinTask, ThreadFactory } +import java.util.concurrent.{ Executor, ExecutorService, ForkJoinPool, ForkJoinTask, ThreadFactory } import scala.util.Try -import org.apache.pekko.util.JavaVersion - object ForkJoinExecutorConfigurator { /** @@ -86,15 +87,28 @@ class ForkJoinExecutorConfigurator(config: Config, prerequisites: DispatcherPrer } class ForkJoinExecutorServiceFactory( + val id: String, val threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory, val parallelism: Int, val asyncMode: Boolean, - val maxPoolSize: Int) + val maxPoolSize: Int, + val virtualize: Boolean) extends ExecutorServiceFactory { + def this(threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory, + parallelism: Int, + asyncMode: Boolean, + maxPoolSize: Int, + virtualize: Boolean) = + this(null, threadFactory, parallelism, asyncMode, maxPoolSize, virtualize) def this(threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory, parallelism: Int, - asyncMode: Boolean) = this(threadFactory, parallelism, asyncMode, ForkJoinPoolConstants.MaxCap) + asyncMode: Boolean) = this(threadFactory, parallelism, asyncMode, ForkJoinPoolConstants.MaxCap, false) + + def this(threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory, + parallelism: Int, + asyncMode: Boolean, + maxPoolSize: Int) = this(threadFactory, parallelism, asyncMode, maxPoolSize, false) private def pekkoJdk9ForkJoinPoolClassOpt: Option[Class[_]] = Try(Class.forName("org.apache.pekko.dispatch.PekkoJdk9ForkJoinPool")).toOption @@ -116,12 +130,50 @@ class ForkJoinExecutorConfigurator(config: Config, prerequisites: DispatcherPrer def this(threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory, parallelism: Int) = this(threadFactory, parallelism, asyncMode = true) - def createExecutorService: ExecutorService = pekkoJdk9ForkJoinPoolHandleOpt match { - case Some(handle) => - handle.invoke(parallelism, threadFactory, maxPoolSize, - MonitorableThreadFactory.doNothing, asyncMode).asInstanceOf[ExecutorService] - case _ => - new PekkoForkJoinPool(parallelism, threadFactory, MonitorableThreadFactory.doNothing, asyncMode) + def createExecutorService: ExecutorService = { + val tf = if (virtualize && JavaVersion.majorVersion >= 21) { + threadFactory match { + // we need to use the thread factory to create carrier thread + case m: MonitorableThreadFactory => new MonitorableCarrierThreadFactory(m.name) + case _ => threadFactory + } + } else threadFactory + + val pool = pekkoJdk9ForkJoinPoolHandleOpt match { + case Some(handle) => + // carrier Thread only exists in JDK 17+ + handle.invoke(parallelism, tf, maxPoolSize, MonitorableThreadFactory.doNothing, asyncMode) + .asInstanceOf[ExecutorService with LoadMetrics] + case _ => + new PekkoForkJoinPool(parallelism, tf, MonitorableThreadFactory.doNothing, asyncMode) + } + + if (virtualize && JavaVersion.majorVersion >= 21) { + // when virtualized, we need enhanced thread factory + val factory: ThreadFactory = threadFactory match { + case MonitorableThreadFactory(name, _, contextClassLoader, exceptionHandler, _) => + new ThreadFactory { + private val vtFactory = newVirtualThreadFactory(name, pool) // use the pool as the scheduler + + override def newThread(r: Runnable): Thread = { + val vt = vtFactory.newThread(r) + vt.setUncaughtExceptionHandler(exceptionHandler) + contextClassLoader.foreach(vt.setContextClassLoader) + vt + } + } + case _ => newVirtualThreadFactory(prerequisites.settings.name, pool); // use the pool as the scheduler + } + // wrap the pool with virtualized executor service + new VirtualizedExecutorService( + factory, // the virtual thread factory + pool, // the underlying pool + (_: Executor) => pool.atFullThrottle(), // the load metrics provider, we use the pool itself + cascadeShutdown = true // cascade shutdown + ) + } else { + pool + } } } @@ -143,12 +195,14 @@ class ForkJoinExecutorConfigurator(config: Config, prerequisites: DispatcherPrer } new ForkJoinExecutorServiceFactory( + id, validate(tf), ThreadPoolConfig.scaledPoolSize( config.getInt("parallelism-min"), config.getDouble("parallelism-factor"), config.getInt("parallelism-max")), asyncMode, - config.getInt("maximum-pool-size")) + config.getInt("maximum-pool-size"), + config.getBoolean("virtualize")) } } diff --git a/actor/src/main/scala/org/apache/pekko/dispatch/ThreadPoolBuilder.scala b/actor/src/main/scala/org/apache/pekko/dispatch/ThreadPoolBuilder.scala index 205c2e4ac77..cf5a856ac2c 100644 --- a/actor/src/main/scala/org/apache/pekko/dispatch/ThreadPoolBuilder.scala +++ b/actor/src/main/scala/org/apache/pekko/dispatch/ThreadPoolBuilder.scala @@ -235,6 +235,18 @@ final case class MonitorableThreadFactory( } } +class MonitorableCarrierThreadFactory(name: String) + extends ForkJoinPool.ForkJoinWorkerThreadFactory { + private val counter = new AtomicLong(0L) + + def newThread(pool: ForkJoinPool): ForkJoinWorkerThread = { + val thread = VirtualThreadSupport.CarrierThreadFactory.newThread(pool) + // Name of the threads for the ForkJoinPool are not customizable. Change it here. + thread.setName(name + "-" + "CarrierThread" + "-" + counter.incrementAndGet()) + thread + } +} + /** * As the name says */ diff --git a/actor/src/main/scala/org/apache/pekko/dispatch/VirtualThreadSupport.scala b/actor/src/main/scala/org/apache/pekko/dispatch/VirtualThreadSupport.scala index 955cdc17102..32723093728 100644 --- a/actor/src/main/scala/org/apache/pekko/dispatch/VirtualThreadSupport.scala +++ b/actor/src/main/scala/org/apache/pekko/dispatch/VirtualThreadSupport.scala @@ -17,11 +17,12 @@ package org.apache.pekko.dispatch -import org.apache.pekko.annotation.InternalApi -import org.apache.pekko.util.JavaVersion +import org.apache.pekko +import pekko.annotation.InternalApi +import pekko.util.JavaVersion import java.lang.invoke.{ MethodHandles, MethodType } -import java.util.concurrent.{ ExecutorService, ForkJoinPool, ThreadFactory } +import java.util.concurrent.{ ExecutorService, ForkJoinPool, ForkJoinWorkerThread, ThreadFactory } import scala.util.control.NonFatal @InternalApi @@ -34,8 +35,26 @@ private[dispatch] object VirtualThreadSupport { val isSupported: Boolean = JavaVersion.majorVersion >= 21 /** - * Create a virtual thread factory with a executor, the executor will be used as the scheduler of - * virtual thread. + * Create a newThreadPerTaskExecutor with the specified thread factory. + */ + def newThreadPerTaskExecutor(threadFactory: ThreadFactory): ExecutorService = { + require(threadFactory != null, "threadFactory should not be null.") + try { + val executorsClazz = ClassLoader.getSystemClassLoader.loadClass("java.util.concurrent.Executors") + val newThreadPerTaskExecutorMethod = lookup.findStatic( + executorsClazz, + "newThreadPerTaskExecutor", + MethodType.methodType(classOf[ExecutorService], classOf[ThreadFactory])) + newThreadPerTaskExecutorMethod.invoke(threadFactory).asInstanceOf[ExecutorService] + } catch { + case NonFatal(e) => + // --add-opens java.base/java.lang=ALL-UNNAMED + throw new UnsupportedOperationException("Failed to create newThreadPerTaskExecutor.", e) + } + } + + /** + * Create a virtual thread factory with the default Virtual Thread executor. */ def newVirtualThreadFactory(prefix: String): ThreadFactory = { require(isSupported, "Virtual thread is not supported.") @@ -57,19 +76,38 @@ private[dispatch] object VirtualThreadSupport { } } - def newThreadPerTaskExecutor(threadFactory: ThreadFactory): ExecutorService = { - require(threadFactory != null, "threadFactory should not be null.") + /** + * Create a virtual thread factory with the specified executor as the scheduler of virtual thread. + */ + def newVirtualThreadFactory(prefix: String, executor: ExecutorService): ThreadFactory = try { - val executorsClazz = ClassLoader.getSystemClassLoader.loadClass("java.util.concurrent.Executors") - val newThreadPerTaskExecutorMethod = lookup.findStatic( - executorsClazz, - "newThreadPerTaskExecutor", - MethodType.methodType(classOf[ExecutorService], classOf[ThreadFactory])) - newThreadPerTaskExecutorMethod.invoke(threadFactory).asInstanceOf[ExecutorService] + val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder") + val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual") + val ofVirtualMethod = classOf[Thread].getDeclaredMethod("ofVirtual") + var builder = ofVirtualMethod.invoke(null) + if (executor != null) { + val clazz = builder.getClass + val field = clazz.getDeclaredField("scheduler") + field.setAccessible(true) + field.set(builder, executor) + } + val nameMethod = ofVirtualClass.getDeclaredMethod("name", classOf[String], classOf[Long]) + val factoryMethod = builderClass.getDeclaredMethod("factory") + val zero = java.lang.Long.valueOf(0L) + builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", zero) + factoryMethod.invoke(builder).asInstanceOf[ThreadFactory] } catch { case NonFatal(e) => // --add-opens java.base/java.lang=ALL-UNNAMED - throw new UnsupportedOperationException("Failed to create newThreadPerTaskExecutor.", e) + throw new UnsupportedOperationException("Failed to create virtual thread factory", e) + } + + object CarrierThreadFactory extends ForkJoinPool.ForkJoinWorkerThreadFactory { + private val clazz = ClassLoader.getSystemClassLoader.loadClass("jdk.internal.misc.CarrierThread") + // TODO lookup.findClass is only available in Java 9 + private val constructor = clazz.getDeclaredConstructor(classOf[ForkJoinPool]) + override def newThread(pool: ForkJoinPool): ForkJoinWorkerThread = { + constructor.newInstance(pool).asInstanceOf[ForkJoinWorkerThread] } } @@ -79,7 +117,7 @@ private[dispatch] object VirtualThreadSupport { def getVirtualThreadDefaultScheduler: ForkJoinPool = try { require(isSupported, "Virtual thread is not supported.") - val clazz = Class.forName("java.lang.VirtualThread") + val clazz = ClassLoader.getSystemClassLoader.loadClass("java.lang.VirtualThread") val fieldName = "DEFAULT_SCHEDULER" val field = clazz.getDeclaredField(fieldName) field.setAccessible(true) diff --git a/actor/src/main/scala/org/apache/pekko/dispatch/VirtualizedExecutorService.scala b/actor/src/main/scala/org/apache/pekko/dispatch/VirtualizedExecutorService.scala index 631c34fdc3e..cefd5ab063a 100644 --- a/actor/src/main/scala/org/apache/pekko/dispatch/VirtualizedExecutorService.scala +++ b/actor/src/main/scala/org/apache/pekko/dispatch/VirtualizedExecutorService.scala @@ -39,13 +39,6 @@ final class VirtualizedExecutorService( require(vtFactory != null, "Virtual thread factory must not be null") require(loadMetricsProvider != null, "Load metrics provider must not be null") - def this(prefix: String, - underlying: ExecutorService, - loadMetricsProvider: Executor => Boolean, - cascadeShutdown: Boolean) = { - this(VirtualThreadSupport.newVirtualThreadFactory(prefix), underlying, loadMetricsProvider, cascadeShutdown) - } - private val executor = VirtualThreadSupport.newThreadPerTaskExecutor(vtFactory) override def atFullThrottle(): Boolean = loadMetricsProvider(this) diff --git a/docs/src/main/paradox/dispatchers.md b/docs/src/main/paradox/dispatchers.md index cac3d0f8f3d..df78d9e77b4 100644 --- a/docs/src/main/paradox/dispatchers.md +++ b/docs/src/main/paradox/dispatchers.md @@ -44,6 +44,16 @@ You can read more about parallelism in the JDK's [ForkJoinPool documentation](ht When Running on Java 9+, you can use `maximum-pool-size` to set the upper bound on the total number of threads allocated by the ForkJoinPool. +**Experimental**: When Running on Java 21+, you can use `virtualize=on` to enable the virtual threads feature. +When using virtual threads, all virtual threads will use the same `unparker`, so you may want to +increase the number of `jdk.unparker.maxPoolSize`. + +#### Requirements: + +1. JDK 21+ +2. add options to the JVM: + - `--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED` + - `--add-opens=java.base/java.lang=ALL-UNNAMED` @@@ Another example that uses the "thread-pool-executor": diff --git a/docs/src/main/paradox/typed/dispatchers.md b/docs/src/main/paradox/typed/dispatchers.md index d65312e136e..3669dda2d8d 100644 --- a/docs/src/main/paradox/typed/dispatchers.md +++ b/docs/src/main/paradox/typed/dispatchers.md @@ -129,6 +129,17 @@ You can read more about parallelism in the JDK's [ForkJoinPool documentation](ht When Running on Java 9+, you can use `maximum-pool-size` to set the upper bound on the total number of threads allocated by the ForkJoinPool. +**Experimental**: When Running on Java 21+, you can use `virtualize=on` to enable the virtual threads feature. +When using virtual threads, all virtual threads will use the same `unparker`, so you may want to +increase the number of `jdk.unparker.maxPoolSize`. + +#### Requirements: + +1. JDK 21+ +2. add options to the JVM: + - `--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED` + - `--add-opens=java.base/java.lang=ALL-UNNAMED` + @@@ @@@ note diff --git a/docs/src/test/scala/docs/dispatcher/DispatcherDocSpec.scala b/docs/src/test/scala/docs/dispatcher/DispatcherDocSpec.scala index e9532ddabd8..7a4213266bb 100644 --- a/docs/src/test/scala/docs/dispatcher/DispatcherDocSpec.scala +++ b/docs/src/test/scala/docs/dispatcher/DispatcherDocSpec.scala @@ -66,6 +66,10 @@ object DispatcherDocSpec { parallelism-factor = 2.0 # Max number of threads to cap factor-based parallelism number to parallelism-max = 10 + + # NOTE: THIS IS AN ADVANCED OPTION, USE WITH CAUTION, requires Java 21+ + # Virtualize this dispatcher as a virtual-thread-executor + virtualize = off } # Throughput defines the maximum number of messages to be # processed per actor before the thread jumps to the next actor. diff --git a/project/JdkOptions.scala b/project/JdkOptions.scala index 852a2bb82da..269ef1a77ee 100644 --- a/project/JdkOptions.scala +++ b/project/JdkOptions.scala @@ -49,6 +49,9 @@ object JdkOptions extends AutoPlugin { lazy val versionSpecificJavaOptions = if (isJdk17orHigher) { + // for virtual threads + "--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED" :: + "--add-opens=java.base/java.lang=ALL-UNNAMED" :: // for aeron "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED" :: // for LevelDB