Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add thread to periodically perform pending cache maintenance #2308

Merged
merged 15 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Documentation
### Maintenance
* Select index settings based on cluster version [#2236](https://github.com/opensearch-project/k-NN/pull/2236)
* Added periodic cache maintenance for QuantizationStateCache and NativeMemoryCache [#2308](https://github.com/opensearch-project/k-NN/pull/2308)
* Added null checks for fieldInfo in ExactSearcher to avoid NPE while running exact search for segments with no vector field [#2278](https://github.com/opensearch-project/k-NN/pull/2278)
* Added Lucene BWC tests [#2313](https://github.com/opensearch-project/k-NN/pull/2313)
* Upgrade jsonpath from 2.8.0 to 2.9.0 [#2325](https://github.com/opensearch-project/k-NN/pull/2325)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import com.google.common.cache.CacheStats;
import com.google.common.cache.RemovalCause;
import com.google.common.cache.RemovalNotification;
import lombok.Getter;
import lombok.Setter;
import org.apache.commons.lang.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -24,6 +26,8 @@
import org.opensearch.knn.common.featureflags.KNNFeatureFlags;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.plugin.stats.StatNames;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.threadpool.Scheduler.Cancellable;

import java.io.Closeable;
import java.util.Deque;
Expand All @@ -47,12 +51,16 @@ public class NativeMemoryCacheManager implements Closeable {

private static final Logger logger = LogManager.getLogger(NativeMemoryCacheManager.class);
private static NativeMemoryCacheManager INSTANCE;
@Setter
private static ThreadPool threadPool;

private Cache<String, NativeMemoryAllocation> cache;
private Deque<String> accessRecencyQueue;
private final ExecutorService executor;
private AtomicBoolean cacheCapacityReached;
private long maxWeight;
@Getter
private Cancellable maintenanceTask;

NativeMemoryCacheManager() {
this.executor = Executors.newSingleThreadExecutor();
shatejas marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -104,6 +112,12 @@ private void initialize(NativeMemoryCacheManagerDto nativeMemoryCacheDTO) {
cacheCapacityReached = new AtomicBoolean(false);
accessRecencyQueue = new ConcurrentLinkedDeque<>();
cache = cacheBuilder.build();

if (threadPool != null) {
owenhalpert marked this conversation as resolved.
Show resolved Hide resolved
startMaintenance(cache);
} else {
logger.warn("ThreadPool is null during NativeMemoryCacheManager initialization. Maintenance will not start.");
}
}

/**
Expand Down Expand Up @@ -142,6 +156,9 @@ public synchronized void rebuildCache(NativeMemoryCacheManagerDto nativeMemoryCa
@Override
public void close() {
    // Begin shutting down the executor owned by this manager.
    executor.shutdown();

    // Maintenance is only scheduled when a ThreadPool was available at initialization,
    // so there may be nothing to cancel.
    if (maintenanceTask == null) {
        return;
    }
    maintenanceTask.cancel();
}

/**
Expand Down Expand Up @@ -449,4 +466,29 @@ private Float getSizeAsPercentage(long size) {
}
return 100 * size / (float) cbLimit;
}

/**
 * Schedules a recurring task that calls {@code cleanUp()} on the given cache. A Guava cache otherwise only
 * performs its pending maintenance (such as evicting expired entries) when it is accessed; running this task
 * makes sure expired entries are also cleaned up on a schedule driven by the configured expiry time.
 * Any maintenance task that was scheduled earlier is cancelled before the new one is registered.
 *
 * @see <a href="https://github.com/google/guava/wiki/cachesexplained#timed-eviction"> Guava Cache Guide</a>
 * @param cacheInstance cache on which to call cleanUp()
 */
private void startMaintenance(Cache<String, NativeMemoryAllocation> cacheInstance) {
    // Keep at most one maintenance task alive for this manager.
    if (maintenanceTask != null) {
        maintenanceTask.cancel();
    }

    // The delay between runs is the configured cache item expiry time.
    TimeValue delayBetweenRuns = KNNSettings.state().getSettingValue(KNNSettings.KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES);

    maintenanceTask = threadPool.scheduleWithFixedDelay(() -> {
        try {
            cacheInstance.cleanUp();
        } catch (Exception e) {
            logger.error("Error cleaning up cache", e);
        }
    }, delayBetweenRuns, ThreadPool.Names.MANAGEMENT);
}
}
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.index.engine.EngineFactory;
import org.opensearch.indices.SystemIndexDescriptor;
import org.opensearch.knn.index.KNNCircuitBreaker;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.plugin.search.KNNConcurrentSearchRequestDecider;
import org.opensearch.knn.index.util.KNNClusterUtil;
import org.opensearch.knn.index.query.KNNQueryBuilder;
Expand Down Expand Up @@ -79,6 +80,7 @@
import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardTransportAction;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache;
import org.opensearch.knn.training.TrainingJobClusterStateListener;
import org.opensearch.knn.training.TrainingJobRunner;
import org.opensearch.knn.training.VectorReader;
Expand Down Expand Up @@ -201,6 +203,8 @@ public Collection<Object> createComponents(
ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance());
TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
QuantizationStateCache.setThreadPool(threadPool);
NativeMemoryCacheManager.setThreadPool(threadPool);
KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client);
KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@
import com.google.common.cache.RemovalCause;
import com.google.common.cache.RemovalNotification;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.common.unit.ByteSizeValue;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.threadpool.Scheduler.Cancellable;
import org.opensearch.threadpool.ThreadPool;

import java.io.Closeable;
import java.io.IOException;
import java.time.Instant;
import java.util.concurrent.TimeUnit;
Expand All @@ -27,14 +31,18 @@
* A thread-safe singleton cache that contains quantization states.
*/
@Log4j2
public class QuantizationStateCache {
public class QuantizationStateCache implements Closeable {

private static volatile QuantizationStateCache instance;
@Setter
private static ThreadPool threadPool;
private Cache<String, QuantizationState> cache;
@Getter
private long maxCacheSizeInKB;
@Getter
private Instant evictedDueToSizeAt;
@Getter
private Cancellable maintenanceTask;

@VisibleForTesting
QuantizationStateCache() {
Expand Down Expand Up @@ -71,6 +79,37 @@ private void buildCache() {
)
.removalListener(this::onRemoval)
.build();

if (threadPool != null) {
startMaintenance(cache);
} else {
log.warn("ThreadPool is null during QuantizationStateCache initialization. Maintenance will not start.");
}
}

/**
 * Registers a periodic job that invokes {@code cleanUp()} on the supplied cache. Guava caches perform
 * maintenance operations (such as evicting expired entries) only when the cache is accessed unless something
 * calls {@code cleanUp()} explicitly, so this job ensures the cache is also cleaned up based on the
 * configured expiry time. A previously scheduled job, if any, is cancelled first.
 *
 * @see <a href="https://github.com/google/guava/wiki/cachesexplained#timed-eviction"> Guava Cache Guide</a>
 * @param cacheInstance cache on which to call cleanUp()
 */
private void startMaintenance(Cache<String, QuantizationState> cacheInstance) {
    final boolean alreadyScheduled = maintenanceTask != null;
    if (alreadyScheduled) {
        // Replace the existing job rather than running two side by side.
        maintenanceTask.cancel();
    }

    final Runnable periodicCleanUp = () -> {
        try {
            cacheInstance.cleanUp();
        } catch (Exception e) {
            log.error("Error cleaning up cache", e);
        }
    };

    // The quantization state cache expiry setting doubles as the delay between clean-up runs.
    final TimeValue expiry = KNNSettings.state().getSettingValue(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES);
    maintenanceTask = threadPool.scheduleWithFixedDelay(periodicCleanUp, expiry, ThreadPool.Names.MANAGEMENT);
}

synchronized void rebuildCache() {
Expand Down Expand Up @@ -129,4 +168,12 @@ private void updateEvictedDueToSizeAt() {
public void clear() {
cache.invalidateAll();
}

@Override
public void close() throws IOException {
    // The task is absent when no ThreadPool was set while the cache was built.
    if (maintenanceTask == null) {
        return;
    }
    maintenanceTask.cancel();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
import lombok.NoArgsConstructor;
import org.opensearch.knn.index.codec.KNN990Codec.KNN990QuantizationStateReader;

import java.io.Closeable;
import java.io.IOException;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class QuantizationStateCacheManager {
public final class QuantizationStateCacheManager implements Closeable {

private static volatile QuantizationStateCacheManager instance;

Expand Down Expand Up @@ -79,4 +80,9 @@ public void setMaxCacheSizeInKB(long maxCacheSizeInKB) {
public void clear() {
QuantizationStateCache.getInstance().clear();
}

/**
 * Closes the {@link QuantizationStateCache} singleton by delegating to its {@code close()} method.
 *
 * @throws IOException if closing the underlying cache fails
 */
@Override
public void close() throws IOException {
owenhalpert marked this conversation as resolved.
Show resolved Hide resolved
QuantizationStateCache.getInstance().close();
}
}
2 changes: 2 additions & 0 deletions src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.index.IndexService;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;
import org.opensearch.plugins.Plugin;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.test.OpenSearchSingleNodeTestCase;
Expand Down Expand Up @@ -86,6 +87,7 @@ protected boolean resetNodeAfterTest() {
public void tearDown() throws Exception {
NativeMemoryCacheManager.getInstance().invalidateAll();
NativeMemoryCacheManager.getInstance().close();
QuantizationStateCacheManager.getInstance().close();
NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close();
NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance().close();
NativeMemoryLoadStrategy.AnonymousLoadStrategy.getInstance().close();
Expand Down
5 changes: 4 additions & 1 deletion src/test/java/org/opensearch/knn/KNNTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
Expand Down Expand Up @@ -73,7 +75,7 @@ protected boolean enableWarningsCheck() {
return false;
}

public void resetState() {
public void resetState() throws IOException {
// Reset all of the counters
for (KNNCounter knnCounter : KNNCounter.values()) {
knnCounter.set(0L);
Expand All @@ -83,6 +85,7 @@ public void resetState() {
// Clean up the cache
NativeMemoryCacheManager.getInstance().invalidateAll();
NativeMemoryCacheManager.getInstance().close();
QuantizationStateCacheManager.getInstance().close();
}

private void initKNNSettings() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
package org.opensearch.knn.index.memory;

import com.google.common.cache.CacheStats;
import org.junit.After;
import org.junit.Before;
import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest;
import org.opensearch.common.settings.Settings;
import org.opensearch.knn.common.exception.OutOfNativeMemoryException;
Expand All @@ -20,6 +22,8 @@
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.test.OpenSearchSingleNodeTestCase;
import org.opensearch.threadpool.Scheduler.Cancellable;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.util.Collection;
Expand All @@ -34,13 +38,29 @@

public class NativeMemoryCacheManagerTests extends OpenSearchSingleNodeTestCase {

private ThreadPool threadPool;

@Before
public void setUp() throws Exception {
    super.setUp();

    // Give NativeMemoryCacheManager a thread pool so managers created in these tests can schedule maintenance.
    final Settings poolSettings = Settings.builder().put("node.name", "NativeMemoryCacheManagerTests").build();
    threadPool = new ThreadPool(poolSettings);
    NativeMemoryCacheManager.setThreadPool(threadPool);
}

/**
 * Terminates the thread pool created in {@link #setUp()}.
 *
 * The parent teardown is intentionally not invoked from this hook: the {@link #tearDown()} override
 * already calls {@code super.tearDown()}, so calling it here as well would run the parent cleanup
 * twice for every test.
 *
 * @throws Exception if terminating the thread pool fails
 */
@After
public void shutdown() throws Exception {
    terminate(threadPool);
}

@Override
public void tearDown() throws Exception {
    // Clear out persistent metadata
    Settings resetCircuitBreaker = Settings.builder().putNull(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED).build();
    ClusterUpdateSettingsRequest resetRequest = new ClusterUpdateSettingsRequest();
    resetRequest.persistentSettings(resetCircuitBreaker);
    client().admin().cluster().updateSettings(resetRequest).get();

    // Close the singleton manager before the parent teardown runs
    NativeMemoryCacheManager.getInstance().close();
    super.tearDown();
}

Expand All @@ -51,6 +71,8 @@ protected Collection<Class<? extends Plugin>> getPlugins() {

public void testRebuildCache() throws ExecutionException, InterruptedException {
NativeMemoryCacheManager nativeMemoryCacheManager = new NativeMemoryCacheManager();
Cancellable task1 = nativeMemoryCacheManager.getMaintenanceTask();
assertNotNull(task1);

// Put entry in cache and check that the weight matches
int size = 10;
Expand All @@ -65,6 +87,9 @@ public void testRebuildCache() throws ExecutionException, InterruptedException {
// Sleep for a second or two so that the executor can invalidate all entries
Thread.sleep(2000);

assertTrue(task1.isCancelled());
assertNotNull(nativeMemoryCacheManager.getMaintenanceTask());

assertEquals(0, nativeMemoryCacheManager.getCacheSizeInKilobytes());
nativeMemoryCacheManager.close();
}
Expand Down Expand Up @@ -378,6 +403,7 @@ public void testCacheCapacity() {

nativeMemoryCacheManager.setCacheCapacityReached(false);
assertFalse(nativeMemoryCacheManager.isCacheCapacityReached());
nativeMemoryCacheManager.close();
}

public void testGetIndicesCacheStats() throws IOException, ExecutionException {
Expand Down Expand Up @@ -464,6 +490,16 @@ public void testGetIndicesCacheStats() throws IOException, ExecutionException {
nativeMemoryCacheManager.close();
}

/**
 * Verifies that a new manager has a maintenance task and that closing the manager cancels it.
 */
public void testMaintenanceScheduled() {
    NativeMemoryCacheManager manager = new NativeMemoryCacheManager();

    // A task must exist right after construction.
    Cancellable scheduled = manager.getMaintenanceTask();
    assertNotNull(scheduled);

    // Closing the manager must cancel that same task.
    manager.close();
    assertTrue(scheduled.isCancelled());
}

private static class TestNativeMemoryAllocation implements NativeMemoryAllocation {

int size;
Expand Down
Loading
Loading