From d4b7fb59ac85603b6701c65704991b8ff15e28f7 Mon Sep 17 00:00:00 2001 From: He-Pin Date: Sun, 19 Jan 2025 12:15:34 +0800 Subject: [PATCH] feat: Add LoadMetrics support for virtual thread executor. --- .../pekko/dispatch/AbstractDispatcher.scala | 36 ++++- .../pekko/dispatch/VirtualThreadSupport.scala | 19 ++- .../dispatch/VirtualizedExecutorService.scala | 124 ++++++++++++++++++ 3 files changed, 174 insertions(+), 5 deletions(-) create mode 100644 actor/src/main/scala/org/apache/pekko/dispatch/VirtualizedExecutorService.scala diff --git a/actor/src/main/scala/org/apache/pekko/dispatch/AbstractDispatcher.scala b/actor/src/main/scala/org/apache/pekko/dispatch/AbstractDispatcher.scala index bb9b301eb3..d73cff521a 100644 --- a/actor/src/main/scala/org/apache/pekko/dispatch/AbstractDispatcher.scala +++ b/actor/src/main/scala/org/apache/pekko/dispatch/AbstractDispatcher.scala @@ -27,7 +27,7 @@ import pekko.annotation.InternalStableApi import pekko.dispatch.affinity.AffinityPoolConfigurator import pekko.dispatch.sysmsg._ import pekko.event.EventStream -import pekko.event.Logging.{ Debug, Error, LogEventException } +import pekko.event.Logging.{ emptyMDC, Debug, Error, LogEventException, Warning } import pekko.util.{ unused, Index, Unsafe } import com.typesafe.config.Config @@ -426,11 +426,39 @@ final class VirtualThreadExecutorConfigurator(config: Config, prerequisites: Dis vt } } - case _ => VirtualThreadSupport.newVirtualThreadFactory(prerequisites.settings.name + "-" + id); + case _ => newVirtualThreadFactory(prerequisites.settings.name + "-" + id); } new ExecutorServiceFactory { - import VirtualThreadSupport._ - override def createExecutorService: ExecutorService = newThreadPerTaskExecutor(tf) + override def createExecutorService: ExecutorService with LoadMetrics = { + // try to get the default scheduler of virtual thread + val pool = { + try { + getVirtualThreadDefaultScheduler + } catch { + case NonFatal(e) => + prerequisites.eventStream.publish( + Warning(e, "VirtualThreadExecutorConfigurator", this.getClass, + """ + |Failed to get the default scheduler of virtual thread, so the `LoadMetrics` is not available when using it with `BalancingDispatcher`. + |Add `--add-opens java.base/java.lang=ALL-UNNAMED` to the JVM options to help this. + |""".stripMargin, emptyMDC)) + null + } + } + val loadMetricsProvider: Executor => Boolean = { + if (pool eq null) { + (_: Executor) => true + } else { + (_: Executor) => pool.getActiveThreadCount >= pool.getParallelism + } + } + new VirtualizedExecutorService( + tf, + pool, // the default scheduler of virtual thread + loadMetricsProvider, + cascadeShutdown = false // we don't want to cascade shutdown the default virtual thread scheduler + ) + } } } } diff --git a/actor/src/main/scala/org/apache/pekko/dispatch/VirtualThreadSupport.scala b/actor/src/main/scala/org/apache/pekko/dispatch/VirtualThreadSupport.scala index 484e600fc3..955cdc1710 100644 --- a/actor/src/main/scala/org/apache/pekko/dispatch/VirtualThreadSupport.scala +++ b/actor/src/main/scala/org/apache/pekko/dispatch/VirtualThreadSupport.scala @@ -21,7 +21,7 @@ import org.apache.pekko.annotation.InternalApi import org.apache.pekko.util.JavaVersion import java.lang.invoke.{ MethodHandles, MethodType } -import java.util.concurrent.{ ExecutorService, ThreadFactory } +import java.util.concurrent.{ ExecutorService, ForkJoinPool, ThreadFactory } import scala.util.control.NonFatal @InternalApi @@ -73,4 +73,21 @@ private[dispatch] object VirtualThreadSupport { } } + /** + * Try to get the default scheduler of virtual thread. + */ + def getVirtualThreadDefaultScheduler: ForkJoinPool = + try { + require(isSupported, "Virtual thread is not supported.") + val clazz = Class.forName("java.lang.VirtualThread") + val fieldName = "DEFAULT_SCHEDULER" + val field = clazz.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(null).asInstanceOf[ForkJoinPool] + } catch { + case NonFatal(e) => + // --add-opens java.base/java.lang=ALL-UNNAMED + throw new UnsupportedOperationException("Failed to get default scheduler of virtual thread.", e) + } + } diff --git a/actor/src/main/scala/org/apache/pekko/dispatch/VirtualizedExecutorService.scala b/actor/src/main/scala/org/apache/pekko/dispatch/VirtualizedExecutorService.scala new file mode 100644 index 0000000000..631c34fdc3 --- /dev/null +++ b/actor/src/main/scala/org/apache/pekko/dispatch/VirtualizedExecutorService.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.pekko.dispatch + +import org.apache.pekko.annotation.InternalApi + +import java.util +import java.util.concurrent.{ Callable, Executor, ExecutorService, Future, ThreadFactory, TimeUnit } + +/** + * A virtualized executor service that creates a new virtual thread for each task. + * Will shut down the underlying executor service when this executor is being shutdown. + * + * INTERNAL API + */ +@InternalApi +final class VirtualizedExecutorService( + vtFactory: ThreadFactory, + underlying: ExecutorService, + loadMetricsProvider: Executor => Boolean, + cascadeShutdown: Boolean) + extends ExecutorService with LoadMetrics { + require(VirtualThreadSupport.isSupported, "Virtual thread is not supported.") + require(vtFactory != null, "Virtual thread factory must not be null") + require(loadMetricsProvider != null, "Load metrics provider must not be null") + + def this(prefix: String, + underlying: ExecutorService, + loadMetricsProvider: Executor => Boolean, + cascadeShutdown: Boolean) = { + this(VirtualThreadSupport.newVirtualThreadFactory(prefix), underlying, loadMetricsProvider, cascadeShutdown) + } + + private val executor = VirtualThreadSupport.newThreadPerTaskExecutor(vtFactory) + + override def atFullThrottle(): Boolean = loadMetricsProvider(this) + + override def shutdown(): Unit = { + executor.shutdown() + if (cascadeShutdown && (underlying ne null)) { + underlying.shutdown() + } + } + + override def shutdownNow(): util.List[Runnable] = { + val r = executor.shutdownNow() + if (cascadeShutdown && (underlying ne null)) { + underlying.shutdownNow() + } + r + } + + override def isShutdown: Boolean = { + if (cascadeShutdown) { + executor.isShutdown && ((underlying eq null) || underlying.isShutdown) + } else { + executor.isShutdown + } + } + + override def isTerminated: Boolean = { + if (cascadeShutdown) { + executor.isTerminated && ((underlying eq null) || underlying.isTerminated) + } else { + executor.isTerminated + } + } + + override def awaitTermination(timeout: Long, unit: TimeUnit): Boolean = { + if (cascadeShutdown) { + executor.awaitTermination(timeout, unit) && ((underlying eq null) || underlying.awaitTermination(timeout, unit)) + } else { + executor.awaitTermination(timeout, unit) + } + } + + override def submit[T](task: Callable[T]): Future[T] = { + executor.submit(task) + } + + override def submit[T](task: Runnable, result: T): Future[T] = { + executor.submit(task, result) + } + + override def submit(task: Runnable): Future[_] = { + executor.submit(task) + } + + override def invokeAll[T](tasks: util.Collection[_ <: Callable[T]]): util.List[Future[T]] = { + executor.invokeAll(tasks) + } + + override def invokeAll[T]( + tasks: util.Collection[_ <: Callable[T]], timeout: Long, unit: TimeUnit): util.List[Future[T]] = { + executor.invokeAll(tasks, timeout, unit) + } + + override def invokeAny[T](tasks: util.Collection[_ <: Callable[T]]): T = { + executor.invokeAny(tasks) + } + + override def invokeAny[T](tasks: util.Collection[_ <: Callable[T]], timeout: Long, unit: TimeUnit): T = { + executor.invokeAny(tasks, timeout, unit) + } + + override def execute(command: Runnable): Unit = { + executor.execute(command) + } +}