diff --git a/CHANGELOG.md b/CHANGELOG.md index 03a8d7974..d150ca13b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,5 +26,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Documentation ### Maintenance * Select index settings based on cluster version[2236](https://github.com/opensearch-project/k-NN/pull/2236) +* Added periodic cache maintenance for QuantizationStateCache and NativeMemoryCache [#2239](https://github.com/opensearch-project/k-NN/issues/2239) * Added null checks for fieldInfo in ExactSearcher to avoid NPE while running exact search for segments with no vector field (#2278)[https://github.com/opensearch-project/k-NN/pull/2278] ### Refactoring diff --git a/src/main/java/org/opensearch/knn/index/CacheMaintainer.java b/src/main/java/org/opensearch/knn/index/CacheMaintainer.java new file mode 100644 index 000000000..c4df39d35 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/CacheMaintainer.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import com.google.common.cache.Cache; + +import java.io.Closeable; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +/** + * Performs periodic maintenance for a Guava cache. The Guava cache is implemented in a way that maintenance operations (such as evicting expired + * entries) will only occur when the cache is accessed. See the <a href="https://github.com/google/guava/wiki/CachesExplained">Guava Cache Guide</a> + * for more details. Thus, to perform any pending maintenance, the cleanUp method will be called periodically from a CacheMaintainer instance. 
+ */ +public class CacheMaintainer implements Closeable { + private final Cache cache; + private final ScheduledExecutorService executor; + private static final int DEFAULT_INTERVAL_SECONDS = 60; + + public CacheMaintainer(Cache cache) { + this.cache = cache; + this.executor = Executors.newSingleThreadScheduledExecutor(); + } + + public void startMaintenance() { + executor.scheduleAtFixedRate(this::cleanCache, DEFAULT_INTERVAL_SECONDS, DEFAULT_INTERVAL_SECONDS, TimeUnit.SECONDS); + } + + public void cleanCache() { + cache.cleanUp(); + } + + @Override + public void close() { + executor.shutdown(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java index b8aecc5a5..e20c81442 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java @@ -22,6 +22,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.knn.common.exception.OutOfNativeMemoryException; import org.opensearch.knn.common.featureflags.KNNFeatureFlags; +import org.opensearch.knn.index.CacheMaintainer; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.plugin.stats.StatNames; @@ -51,6 +52,7 @@ public class NativeMemoryCacheManager implements Closeable { private Cache cache; private Deque accessRecencyQueue; private final ExecutorService executor; + private CacheMaintainer cacheMaintainer; private AtomicBoolean cacheCapacityReached; private long maxWeight; @@ -87,6 +89,10 @@ private void initialize() { } private void initialize(NativeMemoryCacheManagerDto nativeMemoryCacheDTO) { + if (cacheMaintainer != null) { + cacheMaintainer.close(); + } + CacheBuilder cacheBuilder = CacheBuilder.newBuilder() .recordStats() .concurrencyLevel(1) @@ -104,6 +110,9 @@ private void initialize(NativeMemoryCacheManagerDto nativeMemoryCacheDTO) { 
cacheCapacityReached = new AtomicBoolean(false); accessRecencyQueue = new ConcurrentLinkedDeque<>(); cache = cacheBuilder.build(); + + this.cacheMaintainer = new CacheMaintainer<>(cache); + this.cacheMaintainer.startMaintenance(); } /** @@ -142,6 +151,9 @@ public synchronized void rebuildCache(NativeMemoryCacheManagerDto nativeMemoryCa @Override public void close() { executor.shutdown(); + if (cacheMaintainer != null) { + cacheMaintainer.close(); + } } /** diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java index f057026b9..77afabacc 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java @@ -14,8 +14,10 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.knn.index.CacheMaintainer; import org.opensearch.knn.index.KNNSettings; +import java.io.Closeable; import java.io.IOException; import java.time.Instant; import java.util.concurrent.TimeUnit; @@ -27,10 +29,11 @@ * A thread-safe singleton cache that contains quantization states. 
*/ @Log4j2 -public class QuantizationStateCache { +public class QuantizationStateCache implements Closeable { private static volatile QuantizationStateCache instance; private Cache cache; + private CacheMaintainer cacheMaintainer; @Getter private long maxCacheSizeInKB; @Getter @@ -58,6 +61,10 @@ static QuantizationStateCache getInstance() { } private void buildCache() { + if (cacheMaintainer != null) { + cacheMaintainer.close(); + } + this.cache = CacheBuilder.newBuilder().concurrencyLevel(1).maximumWeight(maxCacheSizeInKB).weigher((k, v) -> { try { return ((QuantizationState) v).toByteArray().length; @@ -71,6 +78,9 @@ private void buildCache() { ) .removalListener(this::onRemoval) .build(); + + this.cacheMaintainer = new CacheMaintainer<>(cache); + this.cacheMaintainer.startMaintenance(); } synchronized void rebuildCache() { @@ -129,4 +139,9 @@ private void updateEvictedDueToSizeAt() { public void clear() { cache.invalidateAll(); } + + @Override + public void close() throws IOException { + cacheMaintainer.close(); + } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java index 932d5cde0..193abed80 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java @@ -79,4 +79,8 @@ public void setMaxCacheSizeInKB(long maxCacheSizeInKB) { public void clear() { QuantizationStateCache.getInstance().clear(); } + + public void close() throws IOException { + QuantizationStateCache.getInstance().close(); + } } diff --git a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java index 7bfce5b94..12b3cbba6 100644 --- 
a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java @@ -35,6 +35,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.index.IndexService; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager; import org.opensearch.plugins.Plugin; import org.opensearch.core.rest.RestStatus; import org.opensearch.test.OpenSearchSingleNodeTestCase; @@ -86,6 +87,7 @@ protected boolean resetNodeAfterTest() { public void tearDown() throws Exception { NativeMemoryCacheManager.getInstance().invalidateAll(); NativeMemoryCacheManager.getInstance().close(); + QuantizationStateCacheManager.getInstance().close(); NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close(); NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance().close(); NativeMemoryLoadStrategy.AnonymousLoadStrategy.getInstance().close(); diff --git a/src/test/java/org/opensearch/knn/KNNTestCase.java b/src/test/java/org/opensearch/knn/KNNTestCase.java index 21b3298be..376692f26 100644 --- a/src/test/java/org/opensearch/knn/KNNTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNTestCase.java @@ -24,8 +24,10 @@ import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.Collections; import java.util.HashSet; import java.util.Map; @@ -73,7 +75,7 @@ protected boolean enableWarningsCheck() { return false; } - public void resetState() { + public void resetState() throws IOException { // Reset all of the counters for (KNNCounter knnCounter : KNNCounter.values()) { knnCounter.set(0L); @@ -83,6 +85,7 @@ public void resetState() { // Clean up 
the cache NativeMemoryCacheManager.getInstance().invalidateAll(); NativeMemoryCacheManager.getInstance().close(); + QuantizationStateCacheManager.getInstance().close(); } private void initKNNSettings() { diff --git a/src/test/java/org/opensearch/knn/index/CacheMaintainerTests.java b/src/test/java/org/opensearch/knn/index/CacheMaintainerTests.java new file mode 100644 index 000000000..18acc54eb --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/CacheMaintainerTests.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import org.junit.Test; + +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; + +public class CacheMaintainerTests { + @Test + public void testCacheEviction() throws InterruptedException { + Cache testCache = CacheBuilder.newBuilder().expireAfterWrite(1, TimeUnit.SECONDS).build(); + + CacheMaintainer cleaner = new CacheMaintainer<>(testCache); + + testCache.put("key1", "value1"); + assertEquals(testCache.size(), 1); + + Thread.sleep(1500); + + cleaner.cleanCache(); + assertEquals(testCache.size(), 0); + + cleaner.close(); + } +} diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java index 5fe41c88c..2433f0265 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java @@ -41,6 +41,7 @@ public void tearDown() throws Exception { Settings circuitBreakerSettings = Settings.builder().putNull(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED).build(); clusterUpdateSettingsRequest.persistentSettings(circuitBreakerSettings); client().admin().cluster().updateSettings(clusterUpdateSettingsRequest).get(); + 
NativeMemoryCacheManager.getInstance().close(); super.tearDown(); } @@ -378,6 +379,7 @@ public void testCacheCapacity() { nativeMemoryCacheManager.setCacheCapacityReached(false); assertFalse(nativeMemoryCacheManager.isCacheCapacityReached()); + nativeMemoryCacheManager.close(); } public void testGetIndicesCacheStats() throws IOException, ExecutionException { diff --git a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java index e5381aec7..88fe21d90 100644 --- a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java +++ b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java @@ -16,6 +16,7 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import java.io.IOException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -417,7 +418,7 @@ public void testRebuildOnTimeExpirySettingsChange() { assertNull("State should be null", retrievedState); } - public void testCacheEvictionDueToSize() { + public void testCacheEvictionDueToSize() throws IOException { String fieldName = "evictionField"; // States have size of slightly over 500 bytes so that adding two will reach the max size of 1 kb for the cache int arrayLength = 112; @@ -445,6 +446,7 @@ public void testCacheEvictionDueToSize() { cache.addQuantizationState(fieldName, state); cache.addQuantizationState(fieldName, state2); cache.clear(); + cache.close(); assertNotNull(cache.getEvictedDueToSizeAt()); } }