Skip to content

Commit

Permalink
feat: Add support for switching scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
He-Pin committed Jan 22, 2025
1 parent 189c893 commit 0d7ad11
Show file tree
Hide file tree
Showing 11 changed files with 241 additions and 35 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.pekko.dispatch

import com.typesafe.config.ConfigFactory

import org.apache.pekko
import pekko.actor.{ Actor, Props }
import pekko.testkit.{ ImplicitSender, PekkoSpec }
import pekko.util.JavaVersion

object ForkJoinPoolVirtualThreadSpec {
val config = ConfigFactory.parseString("""
|custom {
| task-dispatcher {
| mailbox-type = "org.apache.pekko.dispatch.SingleConsumerOnlyUnboundedMailbox"
| throughput = 5
| fork-join-executor {
| parallelism-factor = 2
| parallelism-max = 2
| parallelism-min = 2
| virtualize = on
| }
| }
|}
""".stripMargin)

class ThreadNameActor extends Actor {

override def receive = {
case "ping" =>
sender() ! Thread.currentThread().getName
}
}

}

class ForkJoinPoolVirtualThreadSpec extends PekkoSpec(ForkJoinPoolVirtualThreadSpec.config) with ImplicitSender {
import ForkJoinPoolVirtualThreadSpec._

"PekkoForkJoinPool" must {

"support virtualization with Virtual Thread" in {
val actor = system.actorOf(Props(new ThreadNameActor).withDispatcher("custom.task-dispatcher"))
for (_ <- 1 to 1000) {
actor ! "ping"
expectMsgPF() { case name: String =>
name should include("ForkJoinPoolVirtualThreadSpec-custom.task-dispatcher-virtual-thread-")
}
}
}

}
}
12 changes: 12 additions & 0 deletions actor/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,18 @@ pekko {
# This config is new in Pekko v1.1.0 and only has an effect if you are running with JDK 9 and above.
# Read the documentation on `java.util.concurrent.ForkJoinPool` to find out more. Default in hex is 0x7fff.
maximum-pool-size = 32767

# This config is new in Pekko v1.2.0 and only has an effect if you are running with JDK 21 and above,
# When set to `on` but underlying runtime does not support virtual threads, an Exception will throw.
# Virtualize this dispatcher as a virtual-thread-executor
# Valid values are: `on`, `off`
#
# Requirements:
# 1. JDK 21+
# 2. add options to the JVM:
# --add-opens=java.base/jdk.internal.misc=ALL-UNNAMED
# --add-opens=java.base/java.lang=ALL-UNNAMED
virtualize = off
}

# This will be used if you have set "executor = "thread-pool-executor""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ final class VirtualThreadExecutorConfigurator(config: Config, prerequisites: Dis
}
}
new VirtualizedExecutorService(
tf,
tf, // the virtual thread factory
pool, // the default scheduler of virtual thread
loadMetricsProvider,
cascadeShutdown = false // we don't want to cascade shutdown the default virtual thread scheduler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
package org.apache.pekko.dispatch

import com.typesafe.config.Config
import org.apache.pekko
import pekko.dispatch.VirtualThreadSupport.newVirtualThreadFactory
import pekko.util.JavaVersion

import java.lang.invoke.{ MethodHandle, MethodHandles, MethodType }
import java.util.concurrent.{ ExecutorService, ForkJoinPool, ForkJoinTask, ThreadFactory }
import java.util.concurrent.{ Executor, ExecutorService, ForkJoinPool, ForkJoinTask, ThreadFactory }
import scala.util.Try

import org.apache.pekko.util.JavaVersion

object ForkJoinExecutorConfigurator {

/**
Expand Down Expand Up @@ -86,15 +87,28 @@ class ForkJoinExecutorConfigurator(config: Config, prerequisites: DispatcherPrer
}

class ForkJoinExecutorServiceFactory(
val id: String,
val threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory,
val parallelism: Int,
val asyncMode: Boolean,
val maxPoolSize: Int)
val maxPoolSize: Int,
val virtualize: Boolean)
extends ExecutorServiceFactory {
def this(threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory,
parallelism: Int,
asyncMode: Boolean,
maxPoolSize: Int,
virtualize: Boolean) =
this(null, threadFactory, parallelism, asyncMode, maxPoolSize, virtualize)

def this(threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory,
parallelism: Int,
asyncMode: Boolean) = this(threadFactory, parallelism, asyncMode, ForkJoinPoolConstants.MaxCap)
asyncMode: Boolean) = this(threadFactory, parallelism, asyncMode, ForkJoinPoolConstants.MaxCap, false)

def this(threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory,
parallelism: Int,
asyncMode: Boolean,
maxPoolSize: Int) = this(threadFactory, parallelism, asyncMode, maxPoolSize, false)

private def pekkoJdk9ForkJoinPoolClassOpt: Option[Class[_]] =
Try(Class.forName("org.apache.pekko.dispatch.PekkoJdk9ForkJoinPool")).toOption
Expand All @@ -116,12 +130,50 @@ class ForkJoinExecutorConfigurator(config: Config, prerequisites: DispatcherPrer
def this(threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory, parallelism: Int) =
this(threadFactory, parallelism, asyncMode = true)

def createExecutorService: ExecutorService = pekkoJdk9ForkJoinPoolHandleOpt match {
case Some(handle) =>
handle.invoke(parallelism, threadFactory, maxPoolSize,
MonitorableThreadFactory.doNothing, asyncMode).asInstanceOf[ExecutorService]
case _ =>
new PekkoForkJoinPool(parallelism, threadFactory, MonitorableThreadFactory.doNothing, asyncMode)
def createExecutorService: ExecutorService = {
val tf = if (virtualize && JavaVersion.majorVersion >= 21) {
threadFactory match {
// we need to use the thread factory to create carrier thread
case m: MonitorableThreadFactory => new MonitorableCarrierThreadFactory(m.name)
case _ => threadFactory
}
} else threadFactory

val pool = pekkoJdk9ForkJoinPoolHandleOpt match {
case Some(handle) =>
// carrier Thread only exists in JDK 17+
handle.invoke(parallelism, tf, maxPoolSize, MonitorableThreadFactory.doNothing, asyncMode)
.asInstanceOf[ExecutorService with LoadMetrics]
case _ =>
new PekkoForkJoinPool(parallelism, tf, MonitorableThreadFactory.doNothing, asyncMode)
}

if (virtualize && JavaVersion.majorVersion >= 21) {
// when virtualized, we need enhanced thread factory
val factory: ThreadFactory = threadFactory match {
case MonitorableThreadFactory(name, _, contextClassLoader, exceptionHandler, _) =>
new ThreadFactory {
private val vtFactory = newVirtualThreadFactory(name, pool) // use the pool as the scheduler

override def newThread(r: Runnable): Thread = {
val vt = vtFactory.newThread(r)
vt.setUncaughtExceptionHandler(exceptionHandler)
contextClassLoader.foreach(vt.setContextClassLoader)
vt
}
}
case _ => newVirtualThreadFactory(prerequisites.settings.name, pool); // use the pool as the scheduler
}
// wrap the pool with virtualized executor service
new VirtualizedExecutorService(
factory, // the virtual thread factory
pool, // the underlying pool
(_: Executor) => pool.atFullThrottle(), // the load metrics provider, we use the pool itself
cascadeShutdown = true // cascade shutdown
)
} else {
pool
}
}
}

Expand All @@ -143,12 +195,14 @@ class ForkJoinExecutorConfigurator(config: Config, prerequisites: DispatcherPrer
}

new ForkJoinExecutorServiceFactory(
id,
validate(tf),
ThreadPoolConfig.scaledPoolSize(
config.getInt("parallelism-min"),
config.getDouble("parallelism-factor"),
config.getInt("parallelism-max")),
asyncMode,
config.getInt("maximum-pool-size"))
config.getInt("maximum-pool-size"),
config.getBoolean("virtualize"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,18 @@ final case class MonitorableThreadFactory(
}
}

class MonitorableCarrierThreadFactory(name: String)
extends ForkJoinPool.ForkJoinWorkerThreadFactory {
private val counter = new AtomicLong(0L)

def newThread(pool: ForkJoinPool): ForkJoinWorkerThread = {
val thread = VirtualThreadSupport.CarrierThreadFactory.newThread(pool)
// Name of the threads for the ForkJoinPool are not customizable. Change it here.
thread.setName(name + "-" + "CarrierThread" + "-" + counter.incrementAndGet())
thread
}
}

/**
* As the name says
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

package org.apache.pekko.dispatch

import org.apache.pekko.annotation.InternalApi
import org.apache.pekko.util.JavaVersion
import org.apache.pekko
import pekko.annotation.InternalApi
import pekko.util.JavaVersion

import java.lang.invoke.{ MethodHandles, MethodType }
import java.util.concurrent.{ ExecutorService, ForkJoinPool, ThreadFactory }
import java.util.concurrent.{ ExecutorService, ForkJoinPool, ForkJoinWorkerThread, ThreadFactory }
import scala.util.control.NonFatal

@InternalApi
Expand All @@ -34,8 +35,26 @@ private[dispatch] object VirtualThreadSupport {
val isSupported: Boolean = JavaVersion.majorVersion >= 21

/**
* Create a virtual thread factory with a executor, the executor will be used as the scheduler of
* virtual thread.
* Create a newThreadPerTaskExecutor with the specified thread factory.
*/
def newThreadPerTaskExecutor(threadFactory: ThreadFactory): ExecutorService = {
require(threadFactory != null, "threadFactory should not be null.")
try {
val executorsClazz = ClassLoader.getSystemClassLoader.loadClass("java.util.concurrent.Executors")
val newThreadPerTaskExecutorMethod = lookup.findStatic(
executorsClazz,
"newThreadPerTaskExecutor",
MethodType.methodType(classOf[ExecutorService], classOf[ThreadFactory]))
newThreadPerTaskExecutorMethod.invoke(threadFactory).asInstanceOf[ExecutorService]
} catch {
case NonFatal(e) =>
// --add-opens java.base/java.lang=ALL-UNNAMED
throw new UnsupportedOperationException("Failed to create newThreadPerTaskExecutor.", e)
}
}

/**
* Create a virtual thread factory with the default Virtual Thread executor.
*/
def newVirtualThreadFactory(prefix: String): ThreadFactory = {
require(isSupported, "Virtual thread is not supported.")
Expand All @@ -57,19 +76,38 @@ private[dispatch] object VirtualThreadSupport {
}
}

def newThreadPerTaskExecutor(threadFactory: ThreadFactory): ExecutorService = {
require(threadFactory != null, "threadFactory should not be null.")
/**
* Create a virtual thread factory with the specified executor as the scheduler of virtual thread.
*/
def newVirtualThreadFactory(prefix: String, executor: ExecutorService): ThreadFactory =
try {
val executorsClazz = ClassLoader.getSystemClassLoader.loadClass("java.util.concurrent.Executors")
val newThreadPerTaskExecutorMethod = lookup.findStatic(
executorsClazz,
"newThreadPerTaskExecutor",
MethodType.methodType(classOf[ExecutorService], classOf[ThreadFactory]))
newThreadPerTaskExecutorMethod.invoke(threadFactory).asInstanceOf[ExecutorService]
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
val ofVirtualMethod = classOf[Thread].getDeclaredMethod("ofVirtual")
var builder = ofVirtualMethod.invoke(null)
if (executor != null) {
val clazz = builder.getClass
val field = clazz.getDeclaredField("scheduler")
field.setAccessible(true)
field.set(builder, executor)
}
val nameMethod = ofVirtualClass.getDeclaredMethod("name", classOf[String], classOf[Long])
val factoryMethod = builderClass.getDeclaredMethod("factory")
val zero = java.lang.Long.valueOf(0L)
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", zero)
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
} catch {
case NonFatal(e) =>
// --add-opens java.base/java.lang=ALL-UNNAMED
throw new UnsupportedOperationException("Failed to create newThreadPerTaskExecutor.", e)
throw new UnsupportedOperationException("Failed to create virtual thread factory", e)
}

object CarrierThreadFactory extends ForkJoinPool.ForkJoinWorkerThreadFactory {
private val clazz = ClassLoader.getSystemClassLoader.loadClass("jdk.internal.misc.CarrierThread")
// TODO lookup.findClass is only available in Java 9
private val constructor = clazz.getDeclaredConstructor(classOf[ForkJoinPool])
override def newThread(pool: ForkJoinPool): ForkJoinWorkerThread = {
constructor.newInstance(pool).asInstanceOf[ForkJoinWorkerThread]
}
}

Expand All @@ -79,7 +117,7 @@ private[dispatch] object VirtualThreadSupport {
def getVirtualThreadDefaultScheduler: ForkJoinPool =
try {
require(isSupported, "Virtual thread is not supported.")
val clazz = Class.forName("java.lang.VirtualThread")
val clazz = ClassLoader.getSystemClassLoader.loadClass("java.lang.VirtualThread")
val fieldName = "DEFAULT_SCHEDULER"
val field = clazz.getDeclaredField(fieldName)
field.setAccessible(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,6 @@ final class VirtualizedExecutorService(
require(vtFactory != null, "Virtual thread factory must not be null")
require(loadMetricsProvider != null, "Load metrics provider must not be null")

def this(prefix: String,
underlying: ExecutorService,
loadMetricsProvider: Executor => Boolean,
cascadeShutdown: Boolean) = {
this(VirtualThreadSupport.newVirtualThreadFactory(prefix), underlying, loadMetricsProvider, cascadeShutdown)
}

private val executor = VirtualThreadSupport.newThreadPerTaskExecutor(vtFactory)

override def atFullThrottle(): Boolean = loadMetricsProvider(this)
Expand Down
10 changes: 10 additions & 0 deletions docs/src/main/paradox/dispatchers.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ You can read more about parallelism in the JDK's [ForkJoinPool documentation](ht

When Running on Java 9+, you can use `maximum-pool-size` to set the upper bound on the total number of threads allocated by the ForkJoinPool.

**Experimental**: When Running on Java 21+, you can use `virtualize=on` to enable the virtual threads feature.
When using virtual threads, all virtual threads will use the same `unparker`, so you may want to
increase the number of `jdk.unparker.maxPoolSize`.

#### Requirements:

1. JDK 21+
2. add options to the JVM:
- `--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED`
- `--add-opens=java.base/java.lang=ALL-UNNAMED`
@@@

Another example that uses the "thread-pool-executor":
Expand Down
11 changes: 11 additions & 0 deletions docs/src/main/paradox/typed/dispatchers.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,17 @@ You can read more about parallelism in the JDK's [ForkJoinPool documentation](ht

When Running on Java 9+, you can use `maximum-pool-size` to set the upper bound on the total number of threads allocated by the ForkJoinPool.

**Experimental**: When Running on Java 21+, you can use `virtualize=on` to enable the virtual threads feature.
When using virtual threads, all virtual threads will use the same `unparker`, so you may want to
increase the number of `jdk.unparker.maxPoolSize`.

#### Requirements:

1. JDK 21+
2. add options to the JVM:
- `--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED`
- `--add-opens=java.base/java.lang=ALL-UNNAMED`

@@@

@@@ note
Expand Down
Loading

0 comments on commit 0d7ad11

Please sign in to comment.