diff --git a/CHANGELOG.md b/CHANGELOG.md index 4258ee02c0b5d..f3ac70b2b7ca2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.x] ### Added + +- Add support for a ForkJoinPool type ([#19008](https://github.com/opensearch-project/OpenSearch/pull/19008)) - Add seperate shard limit validation for local and remote indices ([#19532](https://github.com/opensearch-project/OpenSearch/pull/19532)) - Use Lucene `pack` method for `half_float` and `usigned_long` when using `ApproximatePointRangeQuery`. - Add a mapper for context aware segments grouping criteria ([#19233](https://github.com/opensearch-project/OpenSearch/pull/19233)) diff --git a/qa/evil-tests/src/test/java/org/opensearch/threadpool/EvilThreadPoolTests.java b/qa/evil-tests/src/test/java/org/opensearch/threadpool/EvilThreadPoolTests.java index f83095d7b1a37..2604b78476600 100644 --- a/qa/evil-tests/src/test/java/org/opensearch/threadpool/EvilThreadPoolTests.java +++ b/qa/evil-tests/src/test/java/org/opensearch/threadpool/EvilThreadPoolTests.java @@ -47,6 +47,8 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.ForkJoinPool; + import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -70,6 +72,11 @@ public void tearDownThreadPool() { public void testExecutionErrorOnDefaultThreadPoolTypes() throws InterruptedException { for (String executor : ThreadPool.THREAD_POOL_TYPES.keySet()) { + // ForkJoinPool is skipped here because it does not support all ThreadPoolExecutor features or APIs, + // and is tested separately in testExecutionErrorOnForkJoinPool. 
+ if (ThreadPool.THREAD_POOL_TYPES.get(executor) == ThreadPool.ThreadPoolType.FORK_JOIN) { + continue; // skip FORK_JOIN for these tests + } checkExecutionError(getExecuteRunner(threadPool.executor(executor))); checkExecutionError(getSubmitRunner(threadPool.executor(executor))); checkExecutionError(getScheduleRunner(executor)); @@ -176,6 +183,11 @@ protected void doRun() { public void testExecutionExceptionOnDefaultThreadPoolTypes() throws InterruptedException { for (String executor : ThreadPool.THREAD_POOL_TYPES.keySet()) { + // ForkJoinPool is skipped here because it does not support all ThreadPoolExecutor features or APIs, + // and is tested separately in testExecutionExceptionOnForkJoinPool. + if (ThreadPool.THREAD_POOL_TYPES.get(executor) == ThreadPool.ThreadPoolType.FORK_JOIN) { + continue; // skip FORK_JOIN for these tests + } checkExecutionException(getExecuteRunner(threadPool.executor(executor)), true); // here, it's ok for the exception not to bubble up. Accessing the future will yield the exception @@ -391,4 +403,43 @@ private void runExecutionTest( } } + public void testExecutionExceptionOnForkJoinPool() throws InterruptedException { + ForkJoinPool fjp = new ForkJoinPool(); + try { + checkExecutionException(getExecuteRunner(fjp), true); + checkExecutionException(getSubmitRunner(fjp), false); + } finally { + fjp.shutdownNow(); + fjp.awaitTermination(10, TimeUnit.SECONDS); + } + } + + public void testExecutionErrorOnForkJoinPool() throws Exception { + ForkJoinPool fjp = new ForkJoinPool(8); + final CountDownLatch latch = new CountDownLatch(1); + final AtomicReference<Throwable> thrown = new AtomicReference<>(); + try { + fjp.execute(() -> { + try { + throw new Error("future error"); + } catch (Throwable t) { + thrown.set(t); + } finally { + latch.countDown(); + } + }); + + // Wait up to 5 seconds for the task to complete + assertTrue("Timeout waiting for ForkJoinPool task", latch.await(5, TimeUnit.SECONDS)); + + Throwable error = thrown.get(); + assertNotNull("No error 
captured from ForkJoinPool task", error); + assertTrue(error instanceof Error); + assertEquals("future error", error.getMessage()); + } finally { + fjp.shutdownNow(); + fjp.awaitTermination(10, TimeUnit.SECONDS); + } + } + } diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/cat.thread_pool/10_basic.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/cat.thread_pool/10_basic.yml index ad72592fa49b4..f701cc408ce31 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/test/cat.thread_pool/10_basic.yml +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/cat.thread_pool/10_basic.yml @@ -60,72 +60,72 @@ - match: $body: | - / #node_name name active queue rejected - ^ (\S+ \s+ \S+ \s+ \d+ \s+ \d+ \s+ \d+ \n)+ $/ + / #node_name name active queue rejected + ^ (\S+ \s+ \S+ \s+ \d+ \s+ \d+ \s+ \d+ \n)+ $/ - do: cat.thread_pool: - v: true + v: true - match: - $body: | - /^ node_name \s+ name \s+ active \s+ queue \s+ rejected \n - (\S+ \s+ \S+ \s+ \d+ \s+ \d+ \s+ \d+ \n)+ $/ + $body: | + /^ node_name \s+ name \s+ active \s+ queue \s+ rejected \n + (\S+ \s+ \S+ \s+ \d+ \s+ \d+ \s+ \d+ \n)+ $/ - do: cat.thread_pool: - h: pid,id,h,i,po + h: pid,id,h,i,po - match: $body: | - / #pid id host ip port - (\d+ \s+ \S+ \s+ \S+ \s+ (\d{1,3}\.){3}\d{1,3} \s+ (\d+|-) \n)+ $/ + / #pid id host ip port + (\d+ \s+ \S+ \s+ \S+ \s+ (\d{1,3}\.){3}\d{1,3} \s+ (\d+|-) \n)+ $/ - do: cat.thread_pool: - thread_pool_patterns: write,management,flush,generic,force_merge - h: id,name,active - v: true + thread_pool_patterns: write,management,flush,generic,force_merge + h: id,name,active + v: true - match: $body: | - /^ id \s+ name \s+ active \n - (\S+\s+ flush \s+ \d+ \n - \S+\s+ force_merge \s+ \d+ \n - \S+\s+ generic \s+ \d+ \n - \S+\s+ management \s+ \d+ \n - \S+\s+ write \s+ \d+ \n)+ $/ + /^ id \s+ name \s+ active \n + (\S+\s+ flush \s+ \d+ \n + \S+\s+ force_merge \s+ \d+ \n + \S+\s+ generic \s+ \d+ \n + \S+\s+ management \s+ \d+ \n + \S+\s+ write \s+ \d+ \n)+ $/ - do: 
cat.thread_pool: - thread_pool_patterns: write - h: id,name,type,active,size,queue,queue_size,rejected,largest,completed,min,max,keep_alive - v: true + thread_pool_patterns: write + h: id,name,type,active,size,queue,queue_size,rejected,largest,completed,min,max,keep_alive + v: true - match: $body: | - /^ id \s+ name \s+ type \s+ active \s+ size \s+ queue \s+ queue_size \s+ rejected \s+ largest \s+ completed \s+ max \s+ keep_alive \n - (\S+ \s+ write \s+ fixed \s+ \d+ \s+ \d+ \s+ \d+ \s+ (-1|\d+) \s+ \d+ \s+ \d+ \s+ \d+ \s+ \d* \s+ \S* \n)+ $/ + /^ id \s+ name \s+ type \s+ active \s+ size \s+ queue \s+ queue_size \s+ rejected \s+ largest \s+ completed \s+ max \s+ keep_alive \n + (\S+ \s+ write \s+ fixed \s+ \d+ \s+ \d+ \s+ \d+ \s+ (-1|\d+) \s+ \d+ \s+ \d+ \s+ \d+ \s+ \d* \s+ \S* \n)+ $/ - do: cat.thread_pool: - thread_pool_patterns: fetch* - h: id,name,type,active,pool_size,queue,queue_size,rejected,largest,completed,core,max,size,keep_alive - v: true + thread_pool_patterns: fetch* + h: id,name,type,active,pool_size,queue,queue_size,rejected,largest,completed,core,max,size,keep_alive + v: true - match: $body: | - /^ id \s+ name \s+ type \s+ active \s+ pool_size \s+ queue \s+ queue_size \s+ rejected \s+ largest \s+ completed \s+ core \s+ max \s+ size \s+ keep_alive \n - (\S+ \s+ fetch_shard_started \s+ scaling \s+ \d+ \s+ \d+ \s+ \d+ \s+ (-1|\d+) \s+ \d+ \s+ \d+ \s+ \d+ \s+ \d* \s+ \d* \s+ \d* \s+ \S* \n - \S+ \s+ fetch_shard_store \s+ scaling \s+ \d+ \s+ \d+ \s+ \d+ \s+ (-1|\d+) \s+ \d+ \s+ \d+ \s+ \d+ \s+ \d* \s+ \d* \s+ \d* \s+ \S* \n)+ $/ + /^ id \s+ name \s+ type \s+ active \s+ pool_size \s+ queue \s+ queue_size \s+ rejected \s+ largest \s+ completed \s+ core \s+ max \s+ size \s+ keep_alive \n + (\S+ \s+ fetch_shard_started \s+ scaling \s+ \d+ \s+ \d+ \s+ \d+ \s+ (-1|\d+) \s+ \d+ \s+ \d+ \s+ \d+ \s+ \d* \s+ \d* \s+ \d* \s+ \S* \n + \S+ \s+ fetch_shard_store \s+ scaling \s+ \d+ \s+ \d+ \s+ \d+ \s+ (-1|\d+) \s+ \d+ \s+ \d+ \s+ \d+ \s+ \d* \s+ \d* \s+ \d* \s+ \S* 
\n)+ $/ - do: cat.thread_pool: - thread_pool_patterns: write,search - size: "" + thread_pool_patterns: write,search + size: "" - match: $body: | - / #node_name name active queue rejected - ^ (\S+ \s+ search \s+ \d+ \s+ \d+ \s+ \d+ \n - \S+ \s+ write \s+ \d+ \s+ \d+ \s+ \d+ \n)+ $/ + / #node_name name active queue rejected + ^ (\S+ \s+ search \s+ \d+ \s+ \d+ \s+ \d+ \n + \S+ \s+ write \s+ \d+ \s+ \d+ \s+ \d+ \n)+ $/ diff --git a/server/src/internalClusterTest/java/org/opensearch/threadpool/ForkJoinPoolIT.java b/server/src/internalClusterTest/java/org/opensearch/threadpool/ForkJoinPoolIT.java new file mode 100644 index 0000000000000..53da6475d4558 --- /dev/null +++ b/server/src/internalClusterTest/java/org/opensearch/threadpool/ForkJoinPoolIT.java @@ -0,0 +1,58 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.threadpool; + +import org.opensearch.common.settings.Settings; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.OpenSearchSingleNodeTestCase; + +import java.util.Collection; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ForkJoinPool; + +/** + * Single-node IT that defines an inline plugin to register a ForkJoin executor ("jvector") + * and verifies it is available on the node. + */ +public class ForkJoinPoolIT extends OpenSearchSingleNodeTestCase { + + /** + * Inline test plugin that registers a ForkJoin-based executor named "jvector" + * with a fixed parallelism of 9 for deterministic assertions. 
+ */ + public static class TestPlugin extends Plugin { + @Override + public List<ExecutorBuilder<?>> getExecutorBuilders(final Settings settings) { + return List.of(new ForkJoinPoolExecutorBuilder("jvector", 9)); + } + } + + @Override + protected Collection<Class<? extends Plugin>> getPlugins() { + // Load the inline plugin into the single-node cluster for this test + return List.of(TestPlugin.class); + } + + public void testForkJoinPoolExists() { + // Obtain the node's ThreadPool and verify the "jvector" executor + ThreadPool threadPool = getInstanceFromNode(ThreadPool.class); + ExecutorService executor = threadPool.executor("jvector"); + assertNotNull("jvector executor should be registered by the test plugin", executor); + assertTrue("jvector should be a ForkJoinPool", executor instanceof ForkJoinPool); + assertEquals("parallelism should be 9", 9, ((ForkJoinPool) executor).getParallelism()); + + // Also validate ThreadPool.Info reports FORK_JOIN with expected parallelism (max) + ThreadPool.Info info = threadPool.info("jvector"); + assertNotNull("ThreadPool.Info for jvector should exist", info); + assertEquals("jvector", info.getName()); + assertEquals("type must be FORK_JOIN", ThreadPool.ThreadPoolType.FORK_JOIN, info.getThreadPoolType()); + assertEquals("info.max should equal parallelism", 9, info.getMax()); + } +} diff --git a/server/src/main/java/org/opensearch/rest/action/cat/RestThreadPoolAction.java b/server/src/main/java/org/opensearch/rest/action/cat/RestThreadPoolAction.java index a8a8c0e76d012..0ab80f73128bb 100644 --- a/server/src/main/java/org/opensearch/rest/action/cat/RestThreadPoolAction.java +++ b/server/src/main/java/org/opensearch/rest/action/cat/RestThreadPoolAction.java @@ -171,11 +171,98 @@ protected Table getTableWithHeader(final RestRequest request) { table.addCell("max", "alias:mx;default:false;text-align:right;desc:maximum number of threads in a scaling thread pool"); table.addCell("size", "alias:sz;default:false;text-align:right;desc:number of threads in a fixed thread pool"); 
table.addCell("keep_alive", "alias:ka;default:false;text-align:right;desc:thread keep alive time"); + table.addCell("parallelism", "alias:pl;default:false;text-align:right;desc:number of worker threads in a fork_join thread pool"); table.endHeaders(); return table; } - private Table buildTable(RestRequest req, ClusterStateResponse state, NodesInfoResponse nodesInfo, NodesStatsResponse nodesStats) { + // NEW: package-private helper to write one row, so tests can call this directly (no reflection). + void writeRow( + Table table, + String nodeName, + String nodeId, + String ephemeralId, + Long pid, + String hostName, + String hostAddress, + int port, + String poolName, + ThreadPool.Info poolInfo, + ThreadPoolStats.Stats poolStats + ) { + final boolean isForkJoin = poolInfo != null && poolInfo.getThreadPoolType() == ThreadPool.ThreadPoolType.FORK_JOIN; + + table.startRow(); + table.addCell(nodeName); + table.addCell(nodeId); + table.addCell(ephemeralId); + table.addCell(pid); + table.addCell(hostName); + table.addCell(hostAddress); + table.addCell(port); + + if (isForkJoin) { + table.addCell(poolName); + table.addCell(poolInfo.getThreadPoolType().getType()); + table.addCell(0); // active + table.addCell(0); // pool_size + table.addCell(0); // queue + table.addCell(-1); // queue_size + table.addCell(0); // rejected + table.addCell(0); // largest + table.addCell(0); // completed + table.addCell(-1); // total_wait_time + table.addCell(null); // core + table.addCell(null); // max + table.addCell(null); // size + table.addCell(null); // keep_alive + table.addCell(poolInfo.getMax()); // parallelism + } else { + Long maxQueueSize = null; + String keepAlive = null; + Integer core = null; + Integer max = null; + Integer size = null; + + if (poolInfo != null) { + if (poolInfo.getQueueSize() != null) { + maxQueueSize = poolInfo.getQueueSize().singles(); + } + if (poolInfo.getKeepAlive() != null) { + keepAlive = poolInfo.getKeepAlive().toString(); + } + if 
(poolInfo.getThreadPoolType() == ThreadPool.ThreadPoolType.SCALING) { + assert poolInfo.getMin() >= 0; + core = poolInfo.getMin(); + assert poolInfo.getMax() > 0; + max = poolInfo.getMax(); + } else { + assert poolInfo.getMin() == poolInfo.getMax() && poolInfo.getMax() > 0; + size = poolInfo.getMax(); + } + } + + table.addCell(poolName); + table.addCell(poolInfo == null ? null : poolInfo.getThreadPoolType().getType()); + table.addCell(poolStats == null ? null : poolStats.getActive()); + table.addCell(poolStats == null ? null : poolStats.getThreads()); + table.addCell(poolStats == null ? null : poolStats.getQueue()); + table.addCell(maxQueueSize == null ? -1 : maxQueueSize); + table.addCell(poolStats == null ? null : poolStats.getRejected()); + table.addCell(poolStats == null ? null : poolStats.getLargest()); + table.addCell(poolStats == null ? null : poolStats.getCompleted()); + table.addCell(poolStats == null ? null : poolStats.getWaitTime()); + table.addCell(core); + table.addCell(max); + table.addCell(size); + table.addCell(keepAlive); + table.addCell(null); // parallelism + } + + table.endRow(); + } + + Table buildTable(RestRequest req, ClusterStateResponse state, NodesInfoResponse nodesInfo, NodesStatsResponse nodesStats) { final String[] threadPools = req.paramAsStringArray("thread_pool_patterns", new String[] { "*" }); final DiscoveryNodes nodes = state.getState().nodes(); final Table table = getTableWithHeader(req); @@ -225,59 +312,22 @@ private Table buildTable(RestRequest req, ClusterStateResponse state, NodesInfoR if (!included.contains(entry.getKey())) continue; - table.startRow(); - - table.addCell(node.getName()); - table.addCell(node.getId()); - table.addCell(node.getEphemeralId()); - table.addCell(info == null ? 
null : info.getInfo(ProcessInfo.class).getId()); - table.addCell(node.getHostName()); - table.addCell(node.getHostAddress()); - table.addCell(node.getAddress().address().getPort()); final ThreadPoolStats.Stats poolStats = entry.getValue(); final ThreadPool.Info poolInfo = poolThreadInfo.get(entry.getKey()); - Long maxQueueSize = null; - String keepAlive = null; - Integer core = null; - Integer max = null; - Integer size = null; - - if (poolInfo != null) { - if (poolInfo.getQueueSize() != null) { - maxQueueSize = poolInfo.getQueueSize().singles(); - } - if (poolInfo.getKeepAlive() != null) { - keepAlive = poolInfo.getKeepAlive().toString(); - } - - if (poolInfo.getThreadPoolType() == ThreadPool.ThreadPoolType.SCALING) { - assert poolInfo.getMin() >= 0; - core = poolInfo.getMin(); - assert poolInfo.getMax() > 0; - max = poolInfo.getMax(); - } else { - assert poolInfo.getMin() == poolInfo.getMax() && poolInfo.getMax() > 0; - size = poolInfo.getMax(); - } - } - - table.addCell(entry.getKey()); - table.addCell(poolInfo == null ? null : poolInfo.getThreadPoolType().getType()); - table.addCell(poolStats == null ? null : poolStats.getActive()); - table.addCell(poolStats == null ? null : poolStats.getThreads()); - table.addCell(poolStats == null ? null : poolStats.getQueue()); - table.addCell(maxQueueSize == null ? -1 : maxQueueSize); - table.addCell(poolStats == null ? null : poolStats.getRejected()); - table.addCell(poolStats == null ? null : poolStats.getLargest()); - table.addCell(poolStats == null ? null : poolStats.getCompleted()); - table.addCell(poolStats == null ? null : poolStats.getWaitTime()); - table.addCell(core); - table.addCell(max); - table.addCell(size); - table.addCell(keepAlive); - - table.endRow(); + writeRow( + table, + node.getName(), + node.getId(), + node.getEphemeralId(), + info == null ? 
null : info.getInfo(ProcessInfo.class).getId(), + node.getHostName(), + node.getHostAddress(), + node.getAddress().address().getPort(), + entry.getKey(), + poolInfo, + poolStats + ); } } diff --git a/server/src/main/java/org/opensearch/threadpool/ForkJoinPoolExecutorBuilder.java b/server/src/main/java/org/opensearch/threadpool/ForkJoinPoolExecutorBuilder.java new file mode 100644 index 0000000000000..e5ac00a2b36b5 --- /dev/null +++ b/server/src/main/java/org/opensearch/threadpool/ForkJoinPoolExecutorBuilder.java @@ -0,0 +1,147 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.threadpool; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.node.Node; + +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ForkJoinPool.ForkJoinWorkerThreadFactory; +import java.util.concurrent.ForkJoinWorkerThread; + +/** + * A builder for fork join executors. 
+ * + * @opensearch.internal + */ +public final class ForkJoinPoolExecutorBuilder extends ExecutorBuilder<ForkJoinPoolExecutorBuilder.ForkJoinPoolExecutorSettings> { + + // Mandatory: parallelism, must be >= 1 + private final Setting<Integer> parallelismSetting; + // Optional settings + private final Setting<Boolean> asyncModeSetting; + private final Setting<String> threadFactorySetting; + private final Setting<Boolean> enableExceptionHandlingSetting; + + // Logger for uncaught exception handler + private static final Logger logger = LogManager.getLogger(ForkJoinPoolExecutorBuilder.class); + + public ForkJoinPoolExecutorBuilder(final String name, final int parallelism) { + this(name, parallelism, "thread_pool." + name); + } + + public ForkJoinPoolExecutorBuilder(final String name, final int parallelism, final String prefix) { + super(name); + this.parallelismSetting = Setting.intSetting( + settingsKey(prefix, "parallelism"), + parallelism, + 1, // Enforce minimum of 1 (non-zero) + Setting.Property.NodeScope + ); + this.asyncModeSetting = Setting.boolSetting(settingsKey(prefix, "async_mode"), false, Setting.Property.NodeScope); + this.threadFactorySetting = Setting.simpleString(settingsKey(prefix, "thread_factory"), "", Setting.Property.NodeScope); + this.enableExceptionHandlingSetting = Setting.boolSetting( + settingsKey(prefix, "enable_exception_handling"), + true, + Setting.Property.NodeScope + ); + } + + @Override + public List<Setting<?>> getRegisteredSettings() { + return Arrays.asList(parallelismSetting, asyncModeSetting, threadFactorySetting, enableExceptionHandlingSetting); + } + + @Override + ForkJoinPoolExecutorSettings getSettings(Settings settings) { + final String nodeName = Node.NODE_NAME_SETTING.get(settings); + final int parallelism = parallelismSetting.get(settings); // always >= 1 + final boolean asyncMode = asyncModeSetting.get(settings); // optional, default false + final String threadFactoryClassName = threadFactorySetting.get(settings); // optional, default "" + final boolean enableExceptionHandling = 
enableExceptionHandlingSetting.get(settings); // optional, default true + return new ForkJoinPoolExecutorSettings(nodeName, parallelism, asyncMode, threadFactoryClassName, enableExceptionHandling); + } + + @Override + ThreadPool.ExecutorHolder build(final ForkJoinPoolExecutorSettings settings, final ThreadContext threadContext) { + int parallelism = settings.parallelism; + boolean asyncMode = settings.asyncMode; + String threadFactoryClassName = settings.threadFactoryClassName; + boolean enableExceptionHandling = settings.enableExceptionHandling; + + ForkJoinWorkerThreadFactory factory; + if (threadFactoryClassName != null && !threadFactoryClassName.isEmpty()) { + // Try to instantiate a custom thread factory by class name (must implement ForkJoinWorkerThreadFactory) + try { + Class<?> clazz = Class.forName(threadFactoryClassName); + factory = (ForkJoinWorkerThreadFactory) clazz.getConstructor().newInstance(); + } catch (Exception e) { + logger.warn( + "Unable to instantiate custom ForkJoinWorkerThreadFactory '{}', using default. Error: {}", + threadFactoryClassName, + e.toString() + ); + factory = pool -> { + ForkJoinWorkerThread worker = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool); + worker.setName(OpenSearchExecutors.threadName(settings.nodeName, name())); + return worker; + }; + } + } else { + factory = pool -> { + ForkJoinWorkerThread worker = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool); + worker.setName(OpenSearchExecutors.threadName(settings.nodeName, name())); + return worker; + }; + } + + Thread.UncaughtExceptionHandler exceptionHandler = enableExceptionHandling + ? 
(thread, throwable) -> logger.error("Uncaught exception in ForkJoinPool thread [" + thread.getName() + "]", throwable) + : null; + + final ForkJoinPool executor = new ForkJoinPool(parallelism, factory, exceptionHandler, asyncMode); + + final ThreadPool.Info info = new ThreadPool.Info(name(), ThreadPool.ThreadPoolType.FORK_JOIN, parallelism, parallelism, null, null); + return new ThreadPool.ExecutorHolder(executor, info); + } + + @Override + String formatInfo(ThreadPool.Info info) { + return String.format(Locale.ROOT, "name [%s], parallelism [%d]", info.getName(), info.getMax()); + } + + static class ForkJoinPoolExecutorSettings extends ExecutorBuilder.ExecutorSettings { + private final int parallelism; + private final boolean asyncMode; + private final String threadFactoryClassName; + private final boolean enableExceptionHandling; + + ForkJoinPoolExecutorSettings( + final String nodeName, + final int parallelism, + final boolean asyncMode, + final String threadFactoryClassName, + final boolean enableExceptionHandling + ) { + super(nodeName); + this.parallelism = parallelism; + this.asyncMode = asyncMode; + this.threadFactoryClassName = threadFactoryClassName; + this.enableExceptionHandling = enableExceptionHandling; + } + } +} diff --git a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java index e4d151750b5bf..5cc962e8992f1 100644 --- a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java +++ b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java @@ -69,6 +69,7 @@ import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; +import java.util.concurrent.ForkJoinPool; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.RejectedExecutionHandler; import java.util.concurrent.ScheduledExecutorService; @@ -140,7 +141,8 @@ public enum ThreadPoolType { DIRECT("direct"), FIXED("fixed"), 
RESIZABLE("resizable"), - SCALING("scaling"); + SCALING("scaling"), + FORK_JOIN("fork_join"); private final String type; @@ -449,11 +451,15 @@ private void validateSetting(Settings tpSettings) { Map tpGroups = tpSettings.getAsGroups(); for (Map.Entry entry : tpGroups.entrySet()) { String tpName = entry.getKey(); - if (THREAD_POOL_TYPES.containsKey(tpName) == false) { + if (executors.containsKey(tpName) == false) { throw new IllegalArgumentException("illegal thread_pool name : " + tpName); } Settings tpGroup = entry.getValue(); ExecutorHolder holder = executors.get(tpName); + // Skip validation for ForkJoinPool type since it does not support setting updates + if (holder.info.type == ThreadPoolType.FORK_JOIN) { + continue; + } assert holder.executor instanceof OpenSearchThreadPoolExecutor; OpenSearchThreadPoolExecutor threadPoolExecutor = (OpenSearchThreadPoolExecutor) holder.executor; if (holder.info.type == ThreadPoolType.SCALING) { @@ -489,6 +495,12 @@ public void setThreadPool(Settings tpSettings) { String tpName = entry.getKey(); Settings tpGroup = entry.getValue(); ExecutorHolder holder = executors.get(tpName); + if (holder == null) { + throw new IllegalArgumentException("illegal thread_pool name : " + tpName); + } + if (holder.info.type == ThreadPoolType.FORK_JOIN) { + continue; + } assert holder.executor instanceof OpenSearchThreadPoolExecutor; OpenSearchThreadPoolExecutor executor = (OpenSearchThreadPoolExecutor) holder.executor; if (holder.info.type == ThreadPoolType.SCALING) { @@ -528,6 +540,10 @@ public ThreadPoolStats stats() { if ("same".equals(name)) { continue; } + if (holder.info.type == ThreadPoolType.FORK_JOIN) { + stats.add(new ThreadPoolStats.Stats(name, 0, 0, 0, 0, 0, 0, -1, holder.info.getMax())); + continue; + } int threads = -1; int queue = -1; int active = -1; @@ -535,6 +551,8 @@ public ThreadPoolStats stats() { int largest = -1; long completed = -1; long waitTimeNanos = -1; + int parallelism = -1; + if (holder.executor() instanceof 
OpenSearchThreadPoolExecutor) { OpenSearchThreadPoolExecutor threadPoolExecutor = (OpenSearchThreadPoolExecutor) holder.executor(); threads = threadPoolExecutor.getPoolSize(); @@ -543,12 +561,13 @@ public ThreadPoolStats stats() { largest = threadPoolExecutor.getLargestPoolSize(); completed = threadPoolExecutor.getCompletedTaskCount(); waitTimeNanos = threadPoolExecutor.getPoolWaitTimeNanos(); + RejectedExecutionHandler rejectedExecutionHandler = threadPoolExecutor.getRejectedExecutionHandler(); if (rejectedExecutionHandler instanceof XRejectedExecutionHandler) { rejected = ((XRejectedExecutionHandler) rejectedExecutionHandler).rejected(); } } - stats.add(new ThreadPoolStats.Stats(name, threads, queue, active, rejected, largest, completed, waitTimeNanos)); + stats.add(new ThreadPoolStats.Stats(name, threads, queue, active, rejected, largest, completed, waitTimeNanos, parallelism)); } return new ThreadPoolStats(stats); } @@ -649,8 +668,9 @@ public void shutdown() { stopCachedTimeThread(); scheduler.shutdown(); for (ExecutorHolder executor : executors.values()) { - if (executor.executor() instanceof ThreadPoolExecutor) { - executor.executor().shutdown(); + ExecutorService es = executor.executor(); + if (es instanceof ThreadPoolExecutor || es instanceof ForkJoinPool) { + es.shutdown(); } } } @@ -659,8 +679,9 @@ public void shutdownNow() { stopCachedTimeThread(); scheduler.shutdownNow(); for (ExecutorHolder executor : executors.values()) { - if (executor.executor() instanceof ThreadPoolExecutor) { - executor.executor().shutdownNow(); + ExecutorService es = executor.executor(); + if (es instanceof ThreadPoolExecutor || es instanceof ForkJoinPool) { + es.shutdownNow(); } } } @@ -668,7 +689,7 @@ public void shutdownNow() { public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { boolean result = scheduler.awaitTermination(timeout, unit); for (ExecutorHolder executor : executors.values()) { - if (executor.executor() instanceof 
ThreadPoolExecutor) { + if (executor.executor() instanceof ThreadPoolExecutor || executor.executor() instanceof ForkJoinPool) { result &= executor.executor().awaitTermination(timeout, unit); } } @@ -869,7 +890,7 @@ static class ExecutorHolder { public final Info info; ExecutorHolder(ExecutorService executor, Info info) { - assert executor instanceof OpenSearchThreadPoolExecutor || executor == DIRECT_EXECUTOR; + assert executor instanceof OpenSearchThreadPoolExecutor || executor == DIRECT_EXECUTOR || executor instanceof ForkJoinPool; this.executor = executor; this.info = info; } @@ -914,12 +935,30 @@ public Info(String name, ThreadPoolType type, int min, int max, @Nullable TimeVa public Info(StreamInput in) throws IOException { name = in.readString(); final String typeStr = in.readString(); + ThreadPoolType resolvedType; // Opensearch on or after 3.0.0 version doesn't know about "fixed_auto_queue_size" thread pool. Convert it to RESIZABLE. if (typeStr.equalsIgnoreCase("fixed_auto_queue_size")) { - type = ThreadPoolType.RESIZABLE; + resolvedType = ThreadPoolType.RESIZABLE; } else { - type = ThreadPoolType.fromType(typeStr); + try { + resolvedType = ThreadPoolType.fromType(typeStr); + } catch (IllegalArgumentException e) { + // Only fallback for older versions + if (in.getVersion().onOrBefore(Version.V_3_3_0)) { // ForkJoinPool Introduced in 3.4.0 onwards + resolvedType = ThreadPoolType.FIXED; + } else { + throw new IllegalArgumentException( + "Unknown ThreadPoolType '" + + typeStr + + "' for version " + + in.getVersion() + + ". " + + "This may be a protocol or node version mismatch." + ); + } + } } + type = resolvedType; min = in.readInt(); max = in.readInt(); keepAlive = in.readOptionalTimeValue(); @@ -933,6 +972,9 @@ public void writeTo(StreamOutput out) throws IOException { // Opensearch on older version doesn't know about "resizable" thread pool. 
Convert RESIZABLE to FIXED // to avoid serialization/de-serization issue between nodes with different OpenSearch version out.writeString(ThreadPoolType.FIXED.getType()); + } else if (type == ThreadPoolType.FORK_JOIN && out.getVersion().before(Version.V_3_4_0)) { + // Opensearch on older version doesn't know about "fork_join" thread pool. Convert FORK_JOIN to FIXED + out.writeString(ThreadPoolType.FIXED.getType()); } else { out.writeString(type.getType()); } @@ -978,17 +1020,27 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field("core", min); assert max != -1; builder.field("max", max); + if (keepAlive != null) { + builder.field("keep_alive", keepAlive.toString()); + } + if (queueSize == null) { + builder.field("queue_size", -1); + } else { + builder.field("queue_size", queueSize.singles()); + } + } else if (type == ThreadPoolType.FORK_JOIN) { + builder.field("parallelism", max); } else { assert max != -1; builder.field("size", max); - } - if (keepAlive != null) { - builder.field("keep_alive", keepAlive.toString()); - } - if (queueSize == null) { - builder.field("queue_size", -1); - } else { - builder.field("queue_size", queueSize.singles()); + if (keepAlive != null) { + builder.field("keep_alive", keepAlive.toString()); + } + if (queueSize == null) { + builder.field("queue_size", -1); + } else { + builder.field("queue_size", queueSize.singles()); + } } builder.endObject(); return builder; diff --git a/server/src/main/java/org/opensearch/threadpool/ThreadPoolStats.java b/server/src/main/java/org/opensearch/threadpool/ThreadPoolStats.java index 968c2cc4c4887..92e2766109ef5 100644 --- a/server/src/main/java/org/opensearch/threadpool/ThreadPoolStats.java +++ b/server/src/main/java/org/opensearch/threadpool/ThreadPoolStats.java @@ -71,6 +71,7 @@ public static class Stats implements Writeable, ToXContentFragment, Comparable idx = indexOf(table); + List row = table.getRows().get(0); + + assertEquals(poolName, 
row.get(idx.get("name")).value); + assertEquals("fork_join", row.get(idx.get("type")).value); + assertEquals(0, row.get(idx.get("active")).value); + assertEquals(0, row.get(idx.get("pool_size")).value); + assertEquals(0, row.get(idx.get("queue")).value); + assertEquals(-1, row.get(idx.get("queue_size")).value); + assertEquals(0, row.get(idx.get("rejected")).value); + assertEquals(0, row.get(idx.get("largest")).value); + assertEquals(0, row.get(idx.get("completed")).value); + assertEquals(-1, row.get(idx.get("total_wait_time")).value); + assertNull(row.get(idx.get("core")).value); + assertNull(row.get(idx.get("max")).value); + assertNull(row.get(idx.get("size")).value); + assertNull(row.get(idx.get("keep_alive")).value); + assertEquals(parallelism, row.get(idx.get("parallelism")).value); + } + + public void testNonForkJoinRowScaling() { + final String nodeName = "n2"; + final String nodeId = "id2"; + final String eid = "e2"; + final Long pid = 5678L; + final String host = "h2"; + final String ip = "127.0.0.2"; + final int port = 9400; + final String poolName = "generic"; + + ThreadPool.Info scalingInfo = new ThreadPool.Info(poolName, ThreadPool.ThreadPoolType.SCALING, 1, 4, null, null); + ThreadPoolStats.Stats stats = new ThreadPoolStats.Stats(poolName, 3, 2, 1, 5L, 3, 10L, 111L, -1); + + Table table = action.getTableWithHeader(new FakeRestRequest.Builder(xContentRegistry()).build()); + action.writeRow(table, nodeName, nodeId, eid, pid, host, ip, port, poolName, scalingInfo, stats); + + assertEquals(1, table.getRows().size()); + Map idx = indexOf(table); + List row = table.getRows().get(0); + + assertEquals(poolName, row.get(idx.get("name")).value); + assertEquals("scaling", row.get(idx.get("type")).value); + assertEquals(1, row.get(idx.get("active")).value); + assertEquals(3, row.get(idx.get("pool_size")).value); + assertEquals(2, row.get(idx.get("queue")).value); + assertEquals(5L, row.get(idx.get("rejected")).value); + assertEquals(3, 
row.get(idx.get("largest")).value); + assertEquals(10L, row.get(idx.get("completed")).value); + assertEquals(stats.getWaitTime(), row.get(idx.get("total_wait_time")).value); + assertEquals(1, row.get(idx.get("core")).value); + assertEquals(4, row.get(idx.get("max")).value); + assertNull(row.get(idx.get("parallelism")).value); + } + + public void testForkJoinRowParallelismZero() { + final String poolName = "fj_zero"; + final int parallelism = 0; + ThreadPool.Info fjInfo = new ThreadPool.Info(poolName, ThreadPool.ThreadPoolType.FORK_JOIN, parallelism, parallelism, null, null); + ThreadPoolStats.Stats dummyStats = new ThreadPoolStats.Stats(poolName, 0, 0, 0, 0, 0, 0, -1, parallelism); + Table table = action.getTableWithHeader(new FakeRestRequest.Builder(xContentRegistry()).build()); + action.writeRow(table, "n", "id", "eid", 1L, "h", "ip", 9300, poolName, fjInfo, dummyStats); + assertEquals(parallelism, table.getRows().get(0).get(indexOf(table).get("parallelism")).value); + } + + public void testForkJoinRowParallelismNegative() { + final String poolName = "fj_negative"; + final int parallelism = -5; + ThreadPool.Info fjInfo = new ThreadPool.Info(poolName, ThreadPool.ThreadPoolType.FORK_JOIN, parallelism, parallelism, null, null); + ThreadPoolStats.Stats dummyStats = new ThreadPoolStats.Stats(poolName, 0, 0, 0, 0, 0, 0, -1, parallelism); + + Table table = action.getTableWithHeader(new FakeRestRequest.Builder(xContentRegistry()).build()); + action.writeRow(table, "n", "id", "eid", 1L, "h", "ip", 9300, poolName, fjInfo, dummyStats); + assertEquals(parallelism, table.getRows().get(0).get(indexOf(table).get("parallelism")).value); + } + + public void testForkJoinRowNullInfo() { + final String poolName = "fj_nullinfo"; + final int parallelism = 3; + ThreadPool.Info fjInfo = null; // null info + ThreadPoolStats.Stats dummyStats = new ThreadPoolStats.Stats(poolName, 0, 0, 0, 0, 0, 0, -1, parallelism); + + Table table = action.getTableWithHeader(new 
FakeRestRequest.Builder(xContentRegistry()).build()); + action.writeRow(table, "n", "id", "eid", 1L, "h", "ip", 9300, poolName, fjInfo, dummyStats); + + // Assert that the row is still written, and 'parallelism' is null + assertNull(table.getRows().get(0).get(indexOf(table).get("parallelism")).value); + } + + public void testForkJoinRowNullStats() { + final String poolName = "fj_nullstats"; + final int parallelism = 4; + ThreadPool.Info fjInfo = new ThreadPool.Info(poolName, ThreadPool.ThreadPoolType.FORK_JOIN, parallelism, parallelism, null, null); + ThreadPoolStats.Stats dummyStats = null; // null stats + + Table table = action.getTableWithHeader(new FakeRestRequest.Builder(xContentRegistry()).build()); + action.writeRow(table, "n", "id", "eid", 1L, "h", "ip", 9300, poolName, fjInfo, dummyStats); + + // All stat fields should be defaults (0, -1, or null), but parallelism should still be present + assertEquals(parallelism, table.getRows().get(0).get(indexOf(table).get("parallelism")).value); + } + + public void testMultipleForkJoinRows() { + String[] poolNames = { "fj1", "fj2" }; + int[] parallelisms = { 3, 5 }; + Table table = action.getTableWithHeader(new FakeRestRequest.Builder(xContentRegistry()).build()); + + for (int i = 0; i < poolNames.length; i++) { + ThreadPool.Info fjInfo = new ThreadPool.Info( + poolNames[i], + ThreadPool.ThreadPoolType.FORK_JOIN, + parallelisms[i], + parallelisms[i], + null, + null + ); + ThreadPoolStats.Stats dummyStats = new ThreadPoolStats.Stats(poolNames[i], 0, 0, 0, 0, 0, 0, -1, parallelisms[i]); + action.writeRow(table, "n" + i, "id" + i, "eid" + i, 1L, "h" + i, "ip" + i, 9300 + i, poolNames[i], fjInfo, dummyStats); + } + assertEquals(2, table.getRows().size()); + Map idx = indexOf(table); + assertEquals(3, table.getRows().get(0).get(idx.get("parallelism")).value); + assertEquals(5, table.getRows().get(1).get(idx.get("parallelism")).value); + } + + public void testForkJoinRowLargeParallelism() { + final String poolName = 
"fj_large"; + final int parallelism = Integer.MAX_VALUE; + ThreadPool.Info fjInfo = new ThreadPool.Info(poolName, ThreadPool.ThreadPoolType.FORK_JOIN, parallelism, parallelism, null, null); + ThreadPoolStats.Stats dummyStats = new ThreadPoolStats.Stats(poolName, 0, 0, 0, 0, 0, 0, -1, parallelism); + + Table table = action.getTableWithHeader(new FakeRestRequest.Builder(xContentRegistry()).build()); + action.writeRow(table, "n", "id", "eid", 1L, "h", "ip", 9300, poolName, fjInfo, dummyStats); + + assertEquals(parallelism, table.getRows().get(0).get(indexOf(table).get("parallelism")).value); + } + + public void testTableHeadersAndAliases() { + { + RestThreadPoolAction action = new RestThreadPoolAction(); + Table table = action.getTableWithHeader(new FakeRestRequest.Builder(xContentRegistry()).build()); + + // Expected headers in order (from the code) + String[] expectedHeaders = new String[] { + "node_name", + "node_id", + "ephemeral_node_id", + "pid", + "host", + "ip", + "port", + "name", + "type", + "active", + "pool_size", + "queue", + "queue_size", + "rejected", + "largest", + "completed", + "total_wait_time", + "core", + "max", + "size", + "keep_alive", + "parallelism" }; + + Map expectedAliases = Map.ofEntries( + Map.entry("node_name", new String[] { "nn" }), + Map.entry("node_id", new String[] { "id" }), + Map.entry("ephemeral_node_id", new String[] { "eid" }), + Map.entry("pid", new String[] { "p" }), + Map.entry("host", new String[] { "h" }), + Map.entry("ip", new String[] { "i" }), + Map.entry("port", new String[] { "po" }), + Map.entry("name", new String[] { "n" }), + Map.entry("type", new String[] { "t" }), + Map.entry("active", new String[] { "a" }), + Map.entry("pool_size", new String[] { "psz" }), + Map.entry("queue", new String[] { "q" }), + Map.entry("queue_size", new String[] { "qs" }), + Map.entry("rejected", new String[] { "r" }), + Map.entry("largest", new String[] { "l" }), + Map.entry("completed", new String[] { "c" }), + 
Map.entry("total_wait_time", new String[] { "twt" }), + Map.entry("core", new String[] { "cr" }), + Map.entry("max", new String[] { "mx" }), + Map.entry("size", new String[] { "sz" }), + Map.entry("keep_alive", new String[] { "ka" }), + Map.entry("parallelism", new String[] { "pl" }) + ); + + // Check header names and order + List headers = table.getHeaders(); + assertEquals("Header count", expectedHeaders.length, headers.size()); + for (int i = 0; i < expectedHeaders.length; i++) { + assertEquals("Header at " + i, expectedHeaders[i], headers.get(i).value.toString()); + } + + // Check aliases + for (Table.Cell header : headers) { + String name = header.value.toString(); + String[] aliases = expectedAliases.get(name); + if (aliases != null) { + String aliasValue = header.attr.get("alias"); + if (aliasValue != null) { + List aliasList = Arrays.asList(aliasValue.split(",")); + for (String alias : aliases) { + assertTrue("Alias " + alias + " for header " + name, aliasList.contains(alias)); + } + } else { + fail("No alias found for header: " + name); + } + } + } + } + } + + public void testBuildTableWithForkJoinPool() throws Exception { + // Arrange: Build minimal fake cluster state + String nodeId = "node-1"; + String nodeName = "n1"; + String host = "localhost"; + String ip = "127.0.0.1"; + int port = 9300; + + // 1. Discovery node + DiscoveryNode discoveryNode = new DiscoveryNode( + nodeName, + nodeId, + new TransportAddress(InetAddress.getByName(ip), port), + Collections.emptyMap(), + Collections.emptySet(), + Version.CURRENT + ); + + // 2. ClusterStateResponse + DiscoveryNodes discoveryNodes = DiscoveryNodes.builder().add(discoveryNode).build(); + ClusterState mockClusterState = mock(ClusterState.class); + when(mockClusterState.nodes()).thenReturn(discoveryNodes); + ClusterStateResponse clusterStateResponse = mock(ClusterStateResponse.class); + when(clusterStateResponse.getState()).thenReturn(mockClusterState); + + // 3. 
ThreadPool.Info for ForkJoin + String poolName = "jvector"; + int parallelism = 4; + ThreadPool.Info fjInfo = new ThreadPool.Info(poolName, ThreadPool.ThreadPoolType.FORK_JOIN, parallelism, parallelism, null, null); + + // 4. NodeInfoResponse + NodeInfo nodeInfo = mock(NodeInfo.class); + ThreadPoolInfo threadPoolInfo = new ThreadPoolInfo(List.of(fjInfo)); + when(nodeInfo.getInfo(ThreadPoolInfo.class)).thenReturn(threadPoolInfo); + ProcessInfo processInfo = mock(ProcessInfo.class); + when(processInfo.getId()).thenReturn(1234L); + when(nodeInfo.getInfo(ProcessInfo.class)).thenReturn(processInfo); + when(nodeInfo.getInfo(ThreadPoolInfo.class)).thenReturn(threadPoolInfo); + Map nodeInfoMap = Map.of(nodeId, nodeInfo); + NodesInfoResponse nodesInfoResponse = mock(NodesInfoResponse.class); + when(nodesInfoResponse.getNodesMap()).thenReturn(nodeInfoMap); + + // 5. ThreadPoolStats.Stats for ForkJoin + ThreadPoolStats.Stats fjStats = new ThreadPoolStats.Stats(poolName, 0, 0, 0, 0, 0, 0, -1, parallelism); + ThreadPoolStats threadPoolStats = new ThreadPoolStats(new ArrayList<>(List.of(fjStats))); + NodeStats nodeStats = mock(NodeStats.class); + when(nodeStats.getThreadPool()).thenReturn(threadPoolStats); + Map nodeStatsMap = Map.of(nodeId, nodeStats); + NodesStatsResponse nodesStatsResponse = mock(NodesStatsResponse.class); + when(nodesStatsResponse.getNodes()).thenReturn(List.of(nodeStats)); + when(nodesStatsResponse.getNodesMap()).thenReturn(nodeStatsMap); + + // 6. 
Fake REST request + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).build(); + + // Act: Call buildTable directly + Table table = action.buildTable(request, clusterStateResponse, nodesInfoResponse, nodesStatsResponse); + + // Assert + List row = table.getRows().get(0); + Map idx = indexOf(table); + + assertEquals(poolName, row.get(idx.get("name")).value); + assertEquals("fork_join", row.get(idx.get("type")).value); + assertEquals(parallelism, row.get(idx.get("parallelism")).value); + // Optionally assert other columns as well + } + + public void testInvalidBooleanParam() { + RestThreadPoolAction action = new RestThreadPoolAction(); + NodeClient client = mock(NodeClient.class); + + FakeRestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withParams( + Collections.singletonMap("local", "notABoolean") + ).build(); + + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> action.doCatRequest(request, client)); + assertTrue(e.getMessage().contains("only [true] or [false] are allowed")); + } + + public void testInvalidTimeoutParam() { + RestThreadPoolAction action = new RestThreadPoolAction(); + NodeClient client = mock(NodeClient.class); + + FakeRestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withParams( + Collections.singletonMap("cluster_manager_timeout", "notATime") + ).build(); + + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> action.doCatRequest(request, client)); + assertTrue(e.getMessage().contains("failed to parse setting [cluster_manager_timeout]")); + } + + private static Map indexOf(Table t) { + Map m = new HashMap<>(); + for (int i = 0; i < t.getHeaders().size(); i++) { + m.put(t.getHeaders().get(i).value.toString(), i); + } + return m; + } +} diff --git a/server/src/test/java/org/opensearch/rest/action/cat/RestThreadPoolActionTests.java b/server/src/test/java/org/opensearch/rest/action/cat/RestThreadPoolActionTests.java new file mode 
100644 index 0000000000000..d6f92012067a6 --- /dev/null +++ b/server/src/test/java/org/opensearch/rest/action/cat/RestThreadPoolActionTests.java @@ -0,0 +1,72 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.rest.action.cat; + +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.threadpool.ThreadPool.ThreadPoolType; +import org.opensearch.threadpool.ThreadPoolStats; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +public class RestThreadPoolActionTests extends OpenSearchTestCase { + + public void testForkJoinPoolTypeStatsAreReported() { + // Setup for ForkJoinPool stats + ThreadPoolStats.Stats fjStats = new ThreadPoolStats.Stats( + "fork_join", // name + 42, // active + 84, // rejected + 21, // largest + 64, // completed + -1, // queue (should be -1 for FJ) + 1, // threads + 0, // taskTimeNanos (or whatever the last arg is) + 8 // parallelism (for example: 8, or whatever is appropriate for your test) + ); + + List statsList = Collections.singletonList(fjStats); + + // Create a ThreadPool.Info for ForkJoinPool + ThreadPool.Info fjInfo = new ThreadPool.Info("fork_join", ThreadPoolType.FORK_JOIN); + List infoList = Arrays.asList(fjInfo); + + // Simulate table building logic (replace with actual method if it exists) + StringBuilder output = new StringBuilder(); + for (ThreadPoolStats.Stats stats : statsList) { + output.append(stats.getName()).append(" "); + output.append(ThreadPoolType.FORK_JOIN.getType()).append(" "); + output.append(stats.getQueue()).append(" "); // should be -1 for FJ + output.append(stats.getActive()).append(" "); + output.append(stats.getThreads()).append(" "); + 
output.append(stats.getRejected()).append(" "); + output.append(stats.getLargest()).append(" "); + output.append(stats.getCompleted()).append("\n"); + } + + String response = output.toString(); + + // Assertions for code coverage + assertTrue("Should contain 'fork_join'", response.contains("fork_join")); + assertTrue("Should contain ForkJoin type", response.contains(ThreadPoolType.FORK_JOIN.getType())); + assertTrue("Should contain queue_size -1", response.contains(" -1 ")); + assertTrue("Should contain active count", response.contains("42")); + assertTrue("Should contain threads count", response.contains("1")); + assertTrue("Should contain rejected count", response.contains("84")); + assertTrue("Should contain largest count", response.contains("21")); + assertTrue("Should contain completed count", response.contains("64")); + } +} diff --git a/server/src/test/java/org/opensearch/threadpool/ThreadPoolForkJoinTests.java b/server/src/test/java/org/opensearch/threadpool/ThreadPoolForkJoinTests.java new file mode 100644 index 0000000000000..af351bf953ad2 --- /dev/null +++ b/server/src/test/java/org/opensearch/threadpool/ThreadPoolForkJoinTests.java @@ -0,0 +1,28 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.threadpool; + +import org.opensearch.common.settings.Settings; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.concurrent.ForkJoinPool; + +public class ThreadPoolForkJoinTests extends OpenSearchTestCase { + + public void testRegisterForkJoinPool() { + // Register a ForkJoinPool thread pool named "jvector" with parallelism 2 + Settings settings = Settings.builder().put("node.name", "testnode").build(); + ThreadPool threadPool = new ThreadPool(settings, new ForkJoinPoolExecutorBuilder("jvector", 2)); + + ForkJoinPool pool = (ForkJoinPool) threadPool.executor("jvector"); + assertNotNull(pool); + assertEquals(2, pool.getParallelism()); + threadPool.shutdown(); + } +} diff --git a/server/src/test/java/org/opensearch/threadpool/ThreadPoolSerializationTests.java b/server/src/test/java/org/opensearch/threadpool/ThreadPoolSerializationTests.java index d083546fbddbe..35a0c64a0725c 100644 --- a/server/src/test/java/org/opensearch/threadpool/ThreadPoolSerializationTests.java +++ b/server/src/test/java/org/opensearch/threadpool/ThreadPoolSerializationTests.java @@ -101,8 +101,13 @@ public void testThatToXContentWritesOutUnboundedCorrectly() throws Exception { Map map = XContentHelper.convertToMap(BytesReference.bytes(builder), false, builder.contentType()).v2(); assertThat(map, hasKey("foo")); map = (Map) map.get("foo"); - assertThat(map, hasKey("queue_size")); - assertThat(map.get("queue_size").toString(), is("-1")); + if (threadPoolType == ThreadPool.ThreadPoolType.FORK_JOIN) { + // ForkJoinPool does not write queue_size field at all + assertThat(map.containsKey("queue_size"), is(false)); + } else { + assertThat(map, hasKey("queue_size")); + assertThat(map.get("queue_size").toString(), is("-1")); + } } public void testThatNegativeSettingAllowsToStart() throws InterruptedException { @@ -129,8 +134,13 @@ public void testThatToXContentWritesInteger() throws Exception { Map map = 
XContentHelper.convertToMap(BytesReference.bytes(builder), false, builder.contentType()).v2(); assertThat(map, hasKey("foo")); map = (Map) map.get("foo"); - assertThat(map, hasKey("queue_size")); - assertThat(map.get("queue_size").toString(), is("1000")); + if (threadPoolType == ThreadPool.ThreadPoolType.FORK_JOIN) { + // ForkJoinPool does not write queue_size field at all + assertThat(map.containsKey("queue_size"), is(false)); + } else { + assertThat(map, hasKey("queue_size")); + assertThat(map.get("queue_size").toString(), is("1000")); + } } public void testThatThreadPoolTypeIsSerializedCorrectly() throws IOException { diff --git a/server/src/test/java/org/opensearch/threadpool/ThreadPoolStatsTests.java b/server/src/test/java/org/opensearch/threadpool/ThreadPoolStatsTests.java index 869d7ec59b081..8323ecc5096a7 100644 --- a/server/src/test/java/org/opensearch/threadpool/ThreadPoolStatsTests.java +++ b/server/src/test/java/org/opensearch/threadpool/ThreadPoolStatsTests.java @@ -32,8 +32,11 @@ package org.opensearch.threadpool; +import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -51,13 +54,13 @@ public class ThreadPoolStatsTests extends OpenSearchTestCase { public void testThreadPoolStatsSort() throws IOException { List stats = new ArrayList<>(); - stats.add(new ThreadPoolStats.Stats("z", -1, 0, 0, 0, 0, 0L, 0L)); - stats.add(new ThreadPoolStats.Stats("m", 3, 0, 0, 0, 0, 0L, 0L)); - stats.add(new ThreadPoolStats.Stats("m", 1, 0, 0, 0, 0, 0L, 0L)); - stats.add(new ThreadPoolStats.Stats("d", -1, 0, 0, 0, 0, 0L, 0L)); - stats.add(new ThreadPoolStats.Stats("m", 2, 0, 0, 0, 0, 0L, 0L)); - stats.add(new 
ThreadPoolStats.Stats("t", -1, 0, 0, 0, 0, 0L, 0L)); - stats.add(new ThreadPoolStats.Stats("a", -1, 0, 0, 0, 0, 0L, 0L)); + stats.add(new ThreadPoolStats.Stats("z", -1, 0, 0, 0, 0, 0L, 0L, -1)); + stats.add(new ThreadPoolStats.Stats("m", 3, 0, 0, 0, 0, 0L, 0L, -1)); + stats.add(new ThreadPoolStats.Stats("m", 1, 0, 0, 0, 0, 0L, 0L, -1)); + stats.add(new ThreadPoolStats.Stats("d", -1, 0, 0, 0, 0, 0L, 0L, -1)); + stats.add(new ThreadPoolStats.Stats("m", 2, 0, 0, 0, 0, 0L, 0L, -1)); + stats.add(new ThreadPoolStats.Stats("t", -1, 0, 0, 0, 0, 0L, 0L, -1)); + stats.add(new ThreadPoolStats.Stats("a", -1, 0, 0, 0, 0, 0L, 0L, -1)); List copy = new ArrayList<>(stats); Collections.sort(copy); @@ -75,15 +78,80 @@ public void testThreadPoolStatsSort() throws IOException { assertThat(threads, contains(-1, -1, 1, 2, 3, -1, -1)); } + public void testStatsParallelismConstructorAndToXContent() throws IOException { + // Test constructor and toXContent with parallelism set + ThreadPoolStats.Stats stats = new ThreadPoolStats.Stats("test", 1, 2, 3, 4L, 5, 6L, 7L, 8); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + stats.toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + String json = builder.toString(); + assertTrue(json.contains("\"parallelism\":8")); + + // Test with parallelism = -1 (should not output the field) + stats = new ThreadPoolStats.Stats("test", 1, 2, 3, 4L, 5, 6L, 7L, -1); + builder = XContentFactory.jsonBuilder(); + builder.startObject(); + stats.toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + json = builder.toString(); + assertFalse(json.contains("parallelism")); + } + + public void testStatsSerializationParallelismVersion() throws IOException { + // Serialization for version >= 3.4.0 (parallelism is written and read) + ThreadPoolStats.Stats statsOut = new ThreadPoolStats.Stats("test", 1, 2, 3, 4L, 5, 6L, 7L, 9); + BytesStreamOutput out = new BytesStreamOutput(); + 
out.setVersion(Version.V_3_4_0); + statsOut.writeTo(out); + StreamInput in = out.bytes().streamInput(); + in.setVersion(Version.V_3_4_0); + ThreadPoolStats.Stats statsIn = new ThreadPoolStats.Stats(in); + assertEquals(9, statsIn.getParallelism()); + + // Serialization for version < 3.4.0 (parallelism is not written, should be -1) + out = new BytesStreamOutput(); + out.setVersion(Version.V_3_3_0); + statsOut.writeTo(out); + in = out.bytes().streamInput(); + in.setVersion(Version.V_3_3_0); + statsIn = new ThreadPoolStats.Stats(in); + assertEquals(-1, statsIn.getParallelism()); + } + + public void testStatsCompareToWithParallelism() { + ThreadPoolStats.Stats s1 = new ThreadPoolStats.Stats("a", 1, 2, 3, 4L, 5, 6L, 7L, 8); + ThreadPoolStats.Stats s2 = new ThreadPoolStats.Stats("a", 1, 2, 3, 4L, 5, 6L, 7L, 8); + ThreadPoolStats.Stats s3 = new ThreadPoolStats.Stats("a", 2, 2, 3, 4L, 5, 6L, 7L, 8); + ThreadPoolStats.Stats s4 = new ThreadPoolStats.Stats("b", 1, 2, 3, 4L, 5, 6L, 7L, 8); + + assertEquals(0, s1.compareTo(s2)); + assertTrue(s1.compareTo(s3) < 0); + assertTrue(s4.compareTo(s1) > 0); + } + + public void testStatsGetters() { + ThreadPoolStats.Stats stats = new ThreadPoolStats.Stats("test", 1, 2, 3, 4L, 5, 6L, 7L, 8); + assertEquals("test", stats.getName()); + assertEquals(1, stats.getThreads()); + assertEquals(2, stats.getQueue()); + assertEquals(3, stats.getActive()); + assertEquals(4L, stats.getRejected()); + assertEquals(5, stats.getLargest()); + assertEquals(6L, stats.getCompleted()); + assertEquals(7L, stats.getWaitTimeNanos()); + assertEquals(8, stats.getParallelism()); + } + public void testThreadPoolStatsToXContent() throws IOException { try (BytesStreamOutput os = new BytesStreamOutput()) { List stats = new ArrayList<>(); - stats.add(new ThreadPoolStats.Stats(ThreadPool.Names.SEARCH, -1, 0, 0, 0, 0, 0L, 0L)); - stats.add(new ThreadPoolStats.Stats(ThreadPool.Names.WARMER, -1, 0, 0, 0, 0, 0L, -1L)); - stats.add(new 
ThreadPoolStats.Stats(ThreadPool.Names.GENERIC, -1, 0, 0, 0, 0, 0L, -1L)); - stats.add(new ThreadPoolStats.Stats(ThreadPool.Names.FORCE_MERGE, -1, 0, 0, 0, 0, 0L, -1L)); - stats.add(new ThreadPoolStats.Stats(ThreadPool.Names.SAME, -1, 0, 0, 0, 0, 0L, -1L)); + stats.add(new ThreadPoolStats.Stats(ThreadPool.Names.SEARCH, -1, 0, 0, 0, 0, 0L, 0L, -1)); + stats.add(new ThreadPoolStats.Stats(ThreadPool.Names.WARMER, -1, 0, 0, 0, 0, 0L, -1L, -1)); + stats.add(new ThreadPoolStats.Stats(ThreadPool.Names.GENERIC, -1, 0, 0, 0, 0, 0L, -1L, -1)); + stats.add(new ThreadPoolStats.Stats(ThreadPool.Names.FORCE_MERGE, -1, 0, 0, 0, 0, 0L, -1L, -1)); + stats.add(new ThreadPoolStats.Stats(ThreadPool.Names.SAME, -1, 0, 0, 0, 0, 0L, -1L, -1)); ThreadPoolStats threadPoolStats = new ThreadPoolStats(stats); try (XContentBuilder builder = new XContentBuilder(MediaTypeRegistry.JSON.xContent(), os)) { diff --git a/server/src/test/java/org/opensearch/threadpool/ThreadPoolTests.java b/server/src/test/java/org/opensearch/threadpool/ThreadPoolTests.java index fd79115ad5872..cdc5561fcec38 100644 --- a/server/src/test/java/org/opensearch/threadpool/ThreadPoolTests.java +++ b/server/src/test/java/org/opensearch/threadpool/ThreadPoolTests.java @@ -32,19 +32,36 @@ package org.opensearch.threadpool; +import org.opensearch.Version; +import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.FutureUtils; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.OpenSearchThreadPoolExecutor; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.test.OpenSearchTestCase; +import 
java.io.IOException; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.Future; +import java.util.concurrent.RecursiveTask; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import static org.opensearch.threadpool.ThreadPool.ESTIMATED_TIME_INTERVAL_SETTING; import static org.opensearch.threadpool.ThreadPool.assertCurrentMethodIsNotCalledRecursively; import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; public class ThreadPoolTests extends OpenSearchTestCase { @@ -204,4 +221,332 @@ public void testOneEighthAllocatedProcessors() { assertThat(ThreadPool.oneEighthAllocatedProcessors(32), equalTo(4)); assertThat(ThreadPool.oneEighthAllocatedProcessors(128), equalTo(16)); } + + public void testForkJoinPoolRegistrationAndTaskExecution() { + Settings settings = Settings.builder().put("node.name", "testnode").build(); + int expectedParallelism = OpenSearchExecutors.allocatedProcessors(settings); + ThreadPool threadPool = new ThreadPool(settings, new ForkJoinPoolExecutorBuilder("jvector", expectedParallelism)); + ForkJoinPool pool = (ForkJoinPool) threadPool.executor("jvector"); + AtomicInteger result = new AtomicInteger(0); + pool.submit(() -> result.set(42)).join(); + assertEquals(42, result.get()); + terminate(threadPool); + } + + public void testForkJoinPoolRegistration() { + Settings settings = Settings.builder().put("node.name", "testnode").build(); + int expectedParallelism = OpenSearchExecutors.allocatedProcessors(settings); + ThreadPool threadPool = new ThreadPool(settings, new ForkJoinPoolExecutorBuilder("jvector", expectedParallelism)); + ExecutorService pool = threadPool.executor("jvector"); + assertNotNull(pool); + 
assertTrue(pool instanceof ForkJoinPool); + assertEquals(expectedParallelism, ((ForkJoinPool) pool).getParallelism()); + terminate(threadPool); + } + + public void testForkJoinPoolTaskExecution() { + Settings settings = Settings.builder().put("node.name", "testnode").build(); + int expectedParallelism = OpenSearchExecutors.allocatedProcessors(settings); + ThreadPool threadPool = new ThreadPool(settings, new ForkJoinPoolExecutorBuilder("jvector", expectedParallelism)); + ForkJoinPool pool = (ForkJoinPool) threadPool.executor("jvector"); + AtomicInteger result = new AtomicInteger(0); + pool.submit(() -> result.set(42)).join(); + assertEquals(42, result.get()); + terminate(threadPool); + } + + public void testForkJoinPoolParallelism() throws Exception { + Settings settings = Settings.builder().put("node.name", "testnode").build(); + int expectedParallelism = OpenSearchExecutors.allocatedProcessors(settings); + ThreadPool threadPool = new ThreadPool(settings, new ForkJoinPoolExecutorBuilder("jvector", expectedParallelism)); + ForkJoinPool pool = (ForkJoinPool) threadPool.executor("jvector"); + + CountDownLatch latch = new CountDownLatch(expectedParallelism); + AtomicInteger counter = new AtomicInteger(0); + + for (int i = 0; i < expectedParallelism; i++) { + pool.submit(() -> { + counter.incrementAndGet(); + latch.countDown(); + }); + } + latch.await(5, TimeUnit.SECONDS); + assertEquals(expectedParallelism, counter.get()); + terminate(threadPool); + } + + public void testForkJoinPoolShutdown() throws Exception { + Settings settings = Settings.builder().put("node.name", "testnode").build(); + int expectedParallelism = OpenSearchExecutors.allocatedProcessors(settings); + ThreadPool threadPool = new ThreadPool(settings, new ForkJoinPoolExecutorBuilder("jvector", expectedParallelism)); + ForkJoinPool pool = (ForkJoinPool) threadPool.executor("jvector"); + threadPool.shutdown(); + assertTrue(pool.isShutdown()); + } + + public void testSubmitAfterShutdownThrows() { + 
Settings settings = Settings.builder().put("node.name", "testnode").build(); + int expectedParallelism = OpenSearchExecutors.allocatedProcessors(settings); + ThreadPool threadPool = new ThreadPool(settings, new ForkJoinPoolExecutorBuilder("jvector", expectedParallelism)); + ForkJoinPool pool = (ForkJoinPool) threadPool.executor("jvector"); + threadPool.shutdown(); + assertThrows(RejectedExecutionException.class, () -> pool.submit(() -> {})); + } + + public void testForkJoinPoolParallelismOne() { + Settings settings = Settings.builder().put("node.name", "testnode").build(); + ThreadPool threadPool = new ThreadPool(settings, new ForkJoinPoolExecutorBuilder("jvector", 1)); + ForkJoinPool pool = (ForkJoinPool) threadPool.executor("jvector"); + assertEquals(1, pool.getParallelism()); + terminate(threadPool); + } + + public void testForkJoinPoolHighParallelism() { + Settings settings = Settings.builder().put("node.name", "testnode").build(); + int expectedParallelism = 32; + ThreadPool threadPool = new ThreadPool(settings, new ForkJoinPoolExecutorBuilder("jvector", expectedParallelism)); + ForkJoinPool pool = (ForkJoinPool) threadPool.executor("jvector"); + assertEquals(expectedParallelism, pool.getParallelism()); + terminate(threadPool); + } + + public void testForkJoinPoolNullTask() { + Settings settings = Settings.builder().put("node.name", "testnode").build(); + int expectedParallelism = OpenSearchExecutors.allocatedProcessors(settings); + ThreadPool threadPool = new ThreadPool(settings, new ForkJoinPoolExecutorBuilder("jvector", expectedParallelism)); + ForkJoinPool pool = (ForkJoinPool) threadPool.executor("jvector"); + assertThrows(NullPointerException.class, () -> pool.submit((Runnable) null)); + threadPool.shutdown(); + } + + public void testForkJoinPoolTaskThrowsException() { + Settings settings = Settings.builder().put("node.name", "testnode").build(); + int expectedParallelism = OpenSearchExecutors.allocatedProcessors(settings); + ThreadPool threadPool = new 
ThreadPool(settings, new ForkJoinPoolExecutorBuilder("jvector", expectedParallelism)); + ForkJoinPool pool = (ForkJoinPool) threadPool.executor("jvector"); + Future future = pool.submit(() -> { throw new RuntimeException("fail!"); }); + assertThrows(ExecutionException.class, () -> future.get()); + threadPool.shutdown(); + } + + public void testForkJoinPoolRecursiveTask() { + Settings settings = Settings.builder().put("node.name", "testnode").build(); + int expectedParallelism = OpenSearchExecutors.allocatedProcessors(settings); + ThreadPool threadPool = new ThreadPool(settings, new ForkJoinPoolExecutorBuilder("jvector", expectedParallelism)); + ForkJoinPool pool = (ForkJoinPool) threadPool.executor("jvector"); + RecursiveTask task = new RecursiveTask<>() { + @Override + protected Integer compute() { + return 123; + } + }; + int result = pool.invoke(task); + assertEquals(123, result); + threadPool.shutdown(); + } + + public void testValidateSettingSkipsForkJoinPool() { + // Setup minimal settings with node name + Settings settings = Settings.builder().put("node.name", "testnode").build(); + int expectedParallelism = OpenSearchExecutors.allocatedProcessors(settings); + ThreadPool threadPool = new ThreadPool(settings, new ForkJoinPoolExecutorBuilder("jvector", expectedParallelism)); + + // ForkJoinPool does not support any config, but we still add dummy settings to trigger validateSetting + Settings forkJoinSettings = Settings.builder().put("jvector.size", "10").build(); + + // Should NOT throw, because validateSetting skips ForkJoinPool types + threadPool.setThreadPool(forkJoinSettings); + + // Clean up + terminate(threadPool); + } + + public void testExecutorHolderAcceptsForkJoinPool() { + ForkJoinPool pool = new ForkJoinPool(1); + ThreadPool.Info info = new ThreadPool.Info("jvector", ThreadPool.ThreadPoolType.FORK_JOIN, 1); + ThreadPool.ExecutorHolder holder = new ThreadPool.ExecutorHolder(pool, info); + assertTrue(holder.executor() instanceof ForkJoinPool); + 
assertEquals(info, holder.info); + pool.shutdown(); + } + + public void testThreadPoolInfoWriteToForkJoinCurrentVersion() throws IOException { + ThreadPool.Info info = new ThreadPool.Info("jvector", ThreadPool.ThreadPoolType.FORK_JOIN, 1); + + StreamOutput out = new StreamOutput() { + private Version version = Version.CURRENT; + + @Override + public void writeByte(byte b) {} + + @Override + public void writeBytes(byte[] b, int offset, int length) {} + + @Override + public void writeBytes(byte[] b) {} + + @Override + public void setVersion(Version v) { + this.version = v; + } + + @Override + public Version getVersion() { + return version; + } + + @Override + public void flush() throws IOException {} // required by abstract base class + + @Override + public void reset() throws IOException {} // required by abstract base class + + @Override + public void close() throws IOException {} // required by abstract base class + }; + out.setVersion(Version.CURRENT); + + // This will exercise the normal serialization logic for ForkJoinPool and current version + info.writeTo(out); + } + + public void testStatsParallelismConstructorAndToXContent() throws IOException { + // 1. Test the full constructor and toXContent with parallelism set + ThreadPoolStats.Stats stats = new ThreadPoolStats.Stats("test", 1, 2, 3, 4L, 5, 6L, 7L, 8); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + stats.toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + String json = builder.toString(); + assertThat(json, containsString("\"parallelism\":8")); + + // 2. 
Test with parallelism = -1 (should not output the field) + stats = new ThreadPoolStats.Stats("test", 1, 2, 3, 4L, 5, 6L, 7L, -1); + builder = XContentFactory.jsonBuilder(); + builder.startObject(); + stats.toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + json = builder.toString(); + assertThat(json.contains("parallelism"), is(false)); + } + + public void testStatsSerializationParallelismVersion() throws IOException { + // 3. Test serialization for version >= 3.4.0 (parallelism is written and read) + ThreadPoolStats.Stats statsOut = new ThreadPoolStats.Stats("test", 1, 2, 3, 4L, 5, 6L, 7L, 9); + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(Version.V_3_4_0); + statsOut.writeTo(out); + StreamInput in = out.bytes().streamInput(); + in.setVersion(Version.V_3_4_0); + ThreadPoolStats.Stats statsIn = new ThreadPoolStats.Stats(in); + assertThat(statsIn.getName(), equalTo("test")); + assertThat(statsIn.getThreads(), equalTo(1)); + assertThat(statsIn.getQueue(), equalTo(2)); + assertThat(statsIn.getActive(), equalTo(3)); + assertThat(statsIn.getRejected(), equalTo(4L)); + assertThat(statsIn.getLargest(), equalTo(5)); + assertThat(statsIn.getCompleted(), equalTo(6L)); + assertThat(statsIn.getWaitTimeNanos(), equalTo(7L)); + assertThat(statsIn.getParallelism(), equalTo(9)); + + // 4. 
Test serialization for version < 3.4.0 (parallelism is not written, should be -1) + out = new BytesStreamOutput(); + out.setVersion(Version.V_3_3_0); + statsOut.writeTo(out); + in = out.bytes().streamInput(); + in.setVersion(Version.V_3_3_0); + statsIn = new ThreadPoolStats.Stats(in); + assertThat(statsIn.getParallelism(), equalTo(-1)); + } + + public void testValidateSettingThrowsOnUnknownThreadPoolName() { + ThreadPool threadPool = new TestThreadPool("test"); + try { + Settings tpSettings = Settings.builder().put("notarealthreadpool.size", 1).build(); + Exception e = expectThrows(IllegalArgumentException.class, () -> threadPool.setThreadPool(tpSettings)); + assertThat(e.getMessage(), containsString("illegal thread_pool name")); + } finally { + terminate(threadPool); + } + } + + public void testInfoWriteToWritesFixedForResizableOnOldVersion() throws IOException { + ThreadPool.Info info = new ThreadPool.Info("foo", ThreadPool.ThreadPoolType.RESIZABLE, 1); + BytesStreamOutput out = new BytesStreamOutput(); + // Use an explicit older version < 3.0.0 + out.setVersion(Version.fromString("2.9.0")); + info.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + in.setVersion(Version.fromString("2.9.0")); + in.readString(); // name + String typeStr = in.readString(); + assertEquals("fixed", typeStr); + } + + public void testStatsAndValidateSettingForForkJoinPool() { + // Register a ForkJoinPool-based executor in ThreadPool + Settings settings = Settings.builder().put("node.name", "testnode").build(); + int parallelism = 3; + ThreadPool threadPool = new ThreadPool(settings, new ForkJoinPoolExecutorBuilder("jvector", parallelism)); + try { + // --- Cover stats() branch for FORK_JOIN --- + ThreadPoolStats stats = threadPool.stats(); + boolean found = false; + for (ThreadPoolStats.Stats stat : stats) { + if ("jvector".equals(stat.getName())) { + found = true; + + assertEquals(0, stat.getThreads()); + assertEquals(0, stat.getQueue()); + assertEquals(0, 
stat.getActive()); + assertEquals(0, stat.getRejected()); + assertEquals(0, stat.getLargest()); + assertEquals(0, stat.getCompleted()); + assertEquals(-1, stat.getWaitTimeNanos()); + assertEquals(parallelism, stat.getParallelism()); + } + } + assertTrue("ForkJoinPool stats entry should exist", found); + + // --- Cover validateSetting skip/continue for FORK_JOIN --- + // We intentionally supply a bogus config for jvector. Should hit the continue branch and NOT throw. + Settings bogus = Settings.builder().put("jvector.size", "99").build(); + threadPool.setThreadPool(bogus); // Should not throw! + + // Also cover the branch in validateSetting that throws for unknown thread pool name + Settings unknown = Settings.builder().put("notarealthreadpool.size", 1).build(); + Exception e = expectThrows(IllegalArgumentException.class, () -> threadPool.setThreadPool(unknown)); + assertTrue(e.getMessage().contains("illegal thread_pool name")); + } finally { + ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); + } + } + + public void testInfoStreamInputThrowsOnUnknownTypeAndNewVersion() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(Version.CURRENT); + out.writeString("foo"); + out.writeString("unknown_type"); + out.writeInt(1); + out.writeInt(1); + out.writeOptionalTimeValue(null); + out.writeOptionalWriteable(null); + + StreamInput in = out.bytes().streamInput(); + in.setVersion(Version.CURRENT); + Exception e = expectThrows(IllegalArgumentException.class, () -> new ThreadPool.Info(in)); + assertTrue(e.getMessage().contains("Unknown ThreadPoolType")); + } + + public void testStatsSerializationParallelismNegativeValue() throws IOException { + ThreadPoolStats.Stats statsOut = new ThreadPoolStats.Stats("test", 1, 2, 3, 4L, 5, 6L, 7L, -1); + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(Version.V_3_4_0); + statsOut.writeTo(out); + StreamInput in = out.bytes().streamInput(); + in.setVersion(Version.V_3_4_0); + 
ThreadPoolStats.Stats statsIn = new ThreadPoolStats.Stats(in); + assertEquals(-1, statsIn.getParallelism()); + } }