diff --git a/CHANGELOG.md b/CHANGELOG.md
index 03a8d7974..d150ca13b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -26,5 +26,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Documentation
### Maintenance
* Select index settings based on cluster version[2236](https://github.com/opensearch-project/k-NN/pull/2236)
+* Added periodic cache maintenance for QuantizationStateCache and NativeMemoryCache [#2239](https://github.com/opensearch-project/k-NN/issues/2239)
* Added null checks for fieldInfo in ExactSearcher to avoid NPE while running exact search for segments with no vector field (#2278)[https://github.com/opensearch-project/k-NN/pull/2278]
### Refactoring
diff --git a/src/main/java/org/opensearch/knn/index/CacheMaintainer.java b/src/main/java/org/opensearch/knn/index/CacheMaintainer.java
new file mode 100644
index 000000000..c4df39d35
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/index/CacheMaintainer.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.index;
+
+import com.google.common.cache.Cache;
+
+import java.io.Closeable;
+import java.util.Objects;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Performs periodic maintenance for a Guava cache. The Guava cache is implemented in a way that maintenance operations (such as evicting
+ * expired entries) will only occur when the cache is accessed. See the
+ * <a href="https://github.com/google/guava/wiki/CachesExplained#when-does-cleanup-happen">Guava Cache Guide</a> for more details. Thus,
+ * to perform any pending maintenance, {@link Cache#cleanUp()} is called periodically from a CacheMaintainer instance, on a dedicated thread.
+ *
+ * <p>Nothing runs until {@link #startMaintenance()} is called; {@link #close()} stops the background work.
+ *
+ * @param <K> type of the keys held by the maintained cache
+ * @param <V> type of the values held by the maintained cache
+ */
+public class CacheMaintainer<K, V> implements Closeable {
+    private static final int DEFAULT_INTERVAL_SECONDS = 60;
+    private static final String THREAD_NAME = "knn-cache-maintainer";
+
+    private final Cache<K, V> cache;
+    private final ScheduledExecutorService executor;
+
+    /**
+     * Creates a maintainer for the given cache. No cleanup is scheduled yet.
+     *
+     * @param cache the Guava cache to clean up periodically; must not be null
+     * @throws NullPointerException if {@code cache} is null
+     */
+    public CacheMaintainer(Cache<K, V> cache) {
+        // Fail here rather than inside the scheduled task, where the exception would only cancel the schedule silently.
+        this.cache = Objects.requireNonNull(cache, "cache");
+        // Named daemon thread: identifiable in thread dumps, and it cannot keep the JVM alive if close() is never called.
+        this.executor = Executors.newSingleThreadScheduledExecutor(task -> {
+            Thread worker = Executors.defaultThreadFactory().newThread(task);
+            worker.setName(THREAD_NAME);
+            worker.setDaemon(true);
+            return worker;
+        });
+    }
+
+    /**
+     * Schedules {@link #cleanCache()} to run every {@value #DEFAULT_INTERVAL_SECONDS} seconds, with the first run after one interval.
+     * Per {@link ScheduledExecutorService#scheduleAtFixedRate}, if a run throws, all subsequent runs are suppressed.
+     */
+    public void startMaintenance() {
+        executor.scheduleAtFixedRate(this::cleanCache, DEFAULT_INTERVAL_SECONDS, DEFAULT_INTERVAL_SECONDS, TimeUnit.SECONDS);
+    }
+
+    /**
+     * Runs the cache's pending maintenance immediately by calling {@link Cache#cleanUp()}.
+     */
+    public void cleanCache() {
+        cache.cleanUp();
+    }
+
+    /**
+     * Shuts the executor down so that no further cleanups are scheduled. Does not wait for a cleanup that is already running.
+     */
+    @Override
+    public void close() {
+        executor.shutdown();
+    }
+}
diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java
index b8aecc5a5..e20c81442 100644
--- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java
+++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java
@@ -22,6 +22,7 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.knn.common.exception.OutOfNativeMemoryException;
import org.opensearch.knn.common.featureflags.KNNFeatureFlags;
+import org.opensearch.knn.index.CacheMaintainer;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.plugin.stats.StatNames;
@@ -51,6 +52,7 @@ public class NativeMemoryCacheManager implements Closeable {
private Cache cache;
private Deque accessRecencyQueue;
private final ExecutorService executor;
+ private CacheMaintainer cacheMaintainer;
private AtomicBoolean cacheCapacityReached;
private long maxWeight;
@@ -87,6 +89,10 @@ private void initialize() {
}
private void initialize(NativeMemoryCacheManagerDto nativeMemoryCacheDTO) {
+ if (cacheMaintainer != null) {
+ cacheMaintainer.close();
+ }
+
CacheBuilder cacheBuilder = CacheBuilder.newBuilder()
.recordStats()
.concurrencyLevel(1)
@@ -104,6 +110,9 @@ private void initialize(NativeMemoryCacheManagerDto nativeMemoryCacheDTO) {
cacheCapacityReached = new AtomicBoolean(false);
accessRecencyQueue = new ConcurrentLinkedDeque<>();
cache = cacheBuilder.build();
+
+ this.cacheMaintainer = new CacheMaintainer<>(cache);
+ this.cacheMaintainer.startMaintenance();
}
/**
@@ -142,6 +151,9 @@ public synchronized void rebuildCache(NativeMemoryCacheManagerDto nativeMemoryCa
@Override
public void close() {
executor.shutdown();
+ if (cacheMaintainer != null) {
+ cacheMaintainer.close();
+ }
}
/**
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java
index f057026b9..77afabacc 100644
--- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java
@@ -14,8 +14,10 @@
import lombok.extern.log4j.Log4j2;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.common.unit.ByteSizeValue;
+import org.opensearch.knn.index.CacheMaintainer;
import org.opensearch.knn.index.KNNSettings;
+import java.io.Closeable;
import java.io.IOException;
import java.time.Instant;
import java.util.concurrent.TimeUnit;
@@ -27,10 +29,11 @@
* A thread-safe singleton cache that contains quantization states.
*/
@Log4j2
-public class QuantizationStateCache {
+public class QuantizationStateCache implements Closeable {
private static volatile QuantizationStateCache instance;
private Cache cache;
+ private CacheMaintainer cacheMaintainer;
@Getter
private long maxCacheSizeInKB;
@Getter
@@ -58,6 +61,10 @@ static QuantizationStateCache getInstance() {
}
private void buildCache() {
+ if (cacheMaintainer != null) {
+ cacheMaintainer.close();
+ }
+
this.cache = CacheBuilder.newBuilder().concurrencyLevel(1).maximumWeight(maxCacheSizeInKB).weigher((k, v) -> {
try {
return ((QuantizationState) v).toByteArray().length;
@@ -71,6 +78,9 @@ private void buildCache() {
)
.removalListener(this::onRemoval)
.build();
+
+ this.cacheMaintainer = new CacheMaintainer<>(cache);
+ this.cacheMaintainer.startMaintenance();
}
synchronized void rebuildCache() {
@@ -129,4 +139,9 @@ private void updateEvictedDueToSizeAt() {
public void clear() {
cache.invalidateAll();
}
+
+ @Override
+ public void close() throws IOException {
+ cacheMaintainer.close();
+ }
}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java
index 932d5cde0..193abed80 100644
--- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java
@@ -79,4 +79,8 @@ public void setMaxCacheSizeInKB(long maxCacheSizeInKB) {
public void clear() {
QuantizationStateCache.getInstance().clear();
}
+
+ public void close() throws IOException {
+ QuantizationStateCache.getInstance().close();
+ }
}
diff --git a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java
index 7bfce5b94..12b3cbba6 100644
--- a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java
+++ b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java
@@ -35,6 +35,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.index.IndexService;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;
import org.opensearch.plugins.Plugin;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.test.OpenSearchSingleNodeTestCase;
@@ -86,6 +87,7 @@ protected boolean resetNodeAfterTest() {
public void tearDown() throws Exception {
NativeMemoryCacheManager.getInstance().invalidateAll();
NativeMemoryCacheManager.getInstance().close();
+ QuantizationStateCacheManager.getInstance().close();
NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close();
NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance().close();
NativeMemoryLoadStrategy.AnonymousLoadStrategy.getInstance().close();
diff --git a/src/test/java/org/opensearch/knn/KNNTestCase.java b/src/test/java/org/opensearch/knn/KNNTestCase.java
index 21b3298be..376692f26 100644
--- a/src/test/java/org/opensearch/knn/KNNTestCase.java
+++ b/src/test/java/org/opensearch/knn/KNNTestCase.java
@@ -24,8 +24,10 @@
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentHelper;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;
import org.opensearch.test.OpenSearchTestCase;
+import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
@@ -73,7 +75,7 @@ protected boolean enableWarningsCheck() {
return false;
}
- public void resetState() {
+ public void resetState() throws IOException {
// Reset all of the counters
for (KNNCounter knnCounter : KNNCounter.values()) {
knnCounter.set(0L);
@@ -83,6 +85,7 @@ public void resetState() {
// Clean up the cache
NativeMemoryCacheManager.getInstance().invalidateAll();
NativeMemoryCacheManager.getInstance().close();
+ QuantizationStateCacheManager.getInstance().close();
}
private void initKNNSettings() {
diff --git a/src/test/java/org/opensearch/knn/index/CacheMaintainerTests.java b/src/test/java/org/opensearch/knn/index/CacheMaintainerTests.java
new file mode 100644
index 000000000..18acc54eb
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/index/CacheMaintainerTests.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.index;
+
+import com.google.common.base.Ticker;
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+import org.junit.Test;
+
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Unit tests for {@link CacheMaintainer}.
+ */
+public class CacheMaintainerTests {
+    /**
+     * An entry that has outlived its expiry must be gone from the cache after a single {@link CacheMaintainer#cleanCache()} call.
+     */
+    @Test
+    public void testCacheEviction() {
+        // Drive the cache's notion of time from a counter so the test does not have to sleep past the expiry.
+        AtomicLong fakeNanos = new AtomicLong();
+        Ticker fakeTicker = new Ticker() {
+            @Override
+            public long read() {
+                return fakeNanos.get();
+            }
+        };
+        Cache<String, String> testCache = CacheBuilder.newBuilder().ticker(fakeTicker).expireAfterWrite(1, TimeUnit.SECONDS).build();
+
+        // try-with-resources shuts the maintainer's executor down even when an assertion fails.
+        try (CacheMaintainer<String, String> cleaner = new CacheMaintainer<>(testCache)) {
+            testCache.put("key1", "value1");
+            assertEquals(1, testCache.size());
+
+            // Move the clock 1.5s forward, past the 1s expireAfterWrite window.
+            fakeNanos.addAndGet(TimeUnit.MILLISECONDS.toNanos(1500));
+
+            cleaner.cleanCache();
+            assertEquals(0, testCache.size());
+        }
+    }
+}
diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java
index 5fe41c88c..2433f0265 100644
--- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java
+++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java
@@ -41,6 +41,7 @@ public void tearDown() throws Exception {
Settings circuitBreakerSettings = Settings.builder().putNull(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED).build();
clusterUpdateSettingsRequest.persistentSettings(circuitBreakerSettings);
client().admin().cluster().updateSettings(clusterUpdateSettingsRequest).get();
+ NativeMemoryCacheManager.getInstance().close();
super.tearDown();
}
@@ -378,6 +379,7 @@ public void testCacheCapacity() {
nativeMemoryCacheManager.setCacheCapacityReached(false);
assertFalse(nativeMemoryCacheManager.isCacheCapacityReached());
+ nativeMemoryCacheManager.close();
}
public void testGetIndicesCacheStats() throws IOException, ExecutionException {
diff --git a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java
index e5381aec7..88fe21d90 100644
--- a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java
+++ b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java
@@ -16,6 +16,7 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
+import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -417,7 +418,7 @@ public void testRebuildOnTimeExpirySettingsChange() {
assertNull("State should be null", retrievedState);
}
- public void testCacheEvictionDueToSize() {
+ public void testCacheEvictionDueToSize() throws IOException {
String fieldName = "evictionField";
// States have size of slightly over 500 bytes so that adding two will reach the max size of 1 kb for the cache
int arrayLength = 112;
@@ -445,6 +446,7 @@ public void testCacheEvictionDueToSize() {
cache.addQuantizationState(fieldName, state);
cache.addQuantizationState(fieldName, state2);
cache.clear();
+ cache.close();
assertNotNull(cache.getEvictedDueToSizeAt());
}
}