Skip to content

Commit

Permalink
Makes MockDiskCache actually use serializers
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Alfonsi <[email protected]>
  • Loading branch information
Peter Alfonsi committed Dec 6, 2024
1 parent f4d3504 commit c0c48d9
Showing 1 changed file with 102 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.common.cache.RemovalListener;
import org.opensearch.common.cache.RemovalNotification;
import org.opensearch.common.cache.RemovalReason;
import org.opensearch.common.cache.serializer.ICacheKeySerializer;
import org.opensearch.common.cache.serializer.Serializer;
import org.opensearch.common.cache.stats.CacheStatsHolder;
import org.opensearch.common.cache.stats.DefaultCacheStatsHolder;
Expand All @@ -32,29 +33,34 @@

public class MockDiskCache<K, V> implements ICache<K, V> {

Map<ICacheKey<K>, V> cache;
Map<ByteArrayWrapper, ByteArrayWrapper> cache;
int maxSize;
long delay;

private final RemovalListener<ICacheKey<K>, V> removalListener;
private final CacheStatsHolder statsHolder; // Only update for number of entries; this is only used to test statsTrackingEnabled logic
// in TSC

public MockDiskCache(int maxSize, long delay, RemovalListener<ICacheKey<K>, V> removalListener, boolean statsTrackingEnabled) {
private final Serializer<ICacheKey<K>, byte[]> keySerializer;
private final Serializer<V, byte[]> valueSerializer;

public MockDiskCache(int maxSize, long delay, RemovalListener<ICacheKey<K>, V> removalListener, boolean statsTrackingEnabled, Serializer<K, byte[]> keySerializer, Serializer<V, byte[]> valueSerializer) {
this.maxSize = maxSize;
this.delay = delay;
this.removalListener = removalListener;
this.cache = new ConcurrentHashMap<ICacheKey<K>, V>();
this.cache = new ConcurrentHashMap<ByteArrayWrapper, ByteArrayWrapper>();
if (statsTrackingEnabled) {
this.statsHolder = new DefaultCacheStatsHolder(List.of(), "mock_disk_cache");
} else {
this.statsHolder = NoopCacheStatsHolder.getInstance();
}
this.keySerializer = new ICacheKeySerializer<>(keySerializer);
this.valueSerializer = valueSerializer;
}

@Override
public V get(ICacheKey<K> key) {
V value = cache.get(key);
V value = deserializeValue(cache.get(serializeKey(key)));
return value;
}

Expand All @@ -70,27 +76,26 @@ public void put(ICacheKey<K> key, V value) {
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
this.cache.put(key, value);
this.cache.put(serializeKey(key), serializeValue(value));
this.statsHolder.incrementItems(List.of());
}

@Override
public V computeIfAbsent(ICacheKey<K> key, LoadAwareCacheLoader<ICacheKey<K>, V> loader) {
V value = cache.computeIfAbsent(key, key1 -> {
return deserializeValue(cache.computeIfAbsent(serializeKey(key), key1 -> {
try {
return loader.load(key);
return serializeValue(loader.load(key));
} catch (Exception e) {
throw new RuntimeException(e);
}
});
return value;
}));
}

@Override
public void invalidate(ICacheKey<K> key) {
V value = this.cache.remove(key);
V value = deserializeValue(this.cache.remove(serializeKey(key)));
if (value != null) {
removalListener.onRemoval(new RemovalNotification<>(key, cache.get(key), RemovalReason.INVALIDATED));
removalListener.onRemoval(new RemovalNotification<>(key, deserializeValue(cache.get(serializeKey(key))), RemovalReason.INVALIDATED));
}
}

Expand All @@ -101,7 +106,7 @@ public void invalidateAll() {

@Override
public Iterable<ICacheKey<K>> keys() {
return () -> new CacheKeyIterator<>(cache, removalListener);
return () -> new CacheKeyIterator<>(cache, removalListener, keySerializer, valueSerializer);
}

@Override
Expand Down Expand Up @@ -129,6 +134,28 @@ public void close() {

}

private ByteArrayWrapper serializeKey(ICacheKey<K> key) {
return new ByteArrayWrapper(keySerializer.serialize(key));
}

private ICacheKey<K> deserializeKey(ByteArrayWrapper binary) {
if (binary == null) {
return null;
}
return keySerializer.deserialize(binary.value);
}

private ByteArrayWrapper serializeValue(V value) {
return new ByteArrayWrapper(valueSerializer.serialize(value));
}

private V deserializeValue(ByteArrayWrapper binary) {
if (binary == null) {
return null;
}
return valueSerializer.deserialize(binary.value);
}

public static class MockDiskCacheFactory implements Factory {

public static final String NAME = "mockDiskCache";
Expand Down Expand Up @@ -183,7 +210,7 @@ public static class Builder<K, V> extends ICacheBuilder<K, V> {

@Override
public ICache<K, V> build() {
return new MockDiskCache<K, V>(this.maxSize, this.delay, this.getRemovalListener(), getStatsTrackingEnabled());
return new MockDiskCache<K, V>(this.maxSize, this.delay, this.getRemovalListener(), getStatsTrackingEnabled(), keySerializer, valueSerializer);
}

public Builder<K, V> setMaxSize(int maxSize) {
Expand Down Expand Up @@ -213,16 +240,21 @@ public Builder<K, V> setValueSerializer(Serializer<V, byte[]> valueSerializer) {
* @param <K> Type of key
* @param <V> Type of value
*/
static class CacheKeyIterator<K, V> implements Iterator<K> {
private final Iterator<Map.Entry<K, V>> entryIterator;
private final Map<K, V> cache;
private final RemovalListener<K, V> removalListener;
private K currentKey;
static class CacheKeyIterator<K, V> implements Iterator<ICacheKey<K>> {
private final Iterator<Map.Entry<ByteArrayWrapper, ByteArrayWrapper>> entryIterator;
private final Map<ByteArrayWrapper, ByteArrayWrapper> cache;
private final RemovalListener<ICacheKey<K>, V> removalListener;
private ICacheKey<K> currentKey;
private final Serializer<ICacheKey<K>, byte[]> keySerializer;
private final Serializer<V, byte[]> valueSerializer;

public CacheKeyIterator(Map<K, V> cache, RemovalListener<K, V> removalListener) {

public CacheKeyIterator(Map<ByteArrayWrapper, ByteArrayWrapper> cache, RemovalListener<ICacheKey<K>, V> removalListener, Serializer<ICacheKey<K>, byte[]> keySerializer, Serializer<V, byte[]> valueSerializer) {
this.entryIterator = cache.entrySet().iterator();
this.removalListener = removalListener;
this.cache = cache;
this.keySerializer = keySerializer;
this.valueSerializer = valueSerializer;
}

@Override
Expand All @@ -231,12 +263,12 @@ public boolean hasNext() {
}

@Override
public K next() {
public ICacheKey<K> next() {
if (!hasNext()) {
throw new NoSuchElementException();
}
Map.Entry<K, V> entry = entryIterator.next();
currentKey = entry.getKey();
Map.Entry<ByteArrayWrapper, ByteArrayWrapper> entry = entryIterator.next();
currentKey = deserializeKey(entry.getKey());
return currentKey;
}

Expand All @@ -245,10 +277,56 @@ public void remove() {
if (currentKey == null) {
throw new IllegalStateException("No element to remove");
}
V value = cache.get(currentKey);
cache.remove(currentKey);
V value = deserializeValue(cache.get(serializeKey(currentKey)));
cache.remove(serializeKey(currentKey));
this.removalListener.onRemoval(new RemovalNotification<>(currentKey, value, RemovalReason.INVALIDATED));
currentKey = null;
}

// TODO: Just duplicated these - sad!
private ByteArrayWrapper serializeKey(ICacheKey<K> key) {
return new ByteArrayWrapper(keySerializer.serialize(key));
}

private ICacheKey<K> deserializeKey(ByteArrayWrapper binary) {
if (binary == null) {
return null;
}
return keySerializer.deserialize(binary.value);
}

private ByteArrayWrapper serializeValue(V value) {
return new ByteArrayWrapper(valueSerializer.serialize(value));
}

private V deserializeValue(ByteArrayWrapper binary) {
if (binary == null) {
return null;
}
return valueSerializer.deserialize(binary.value);
}
}

// Duplicated from EhcacheDiskCache
static class ByteArrayWrapper {
private final byte[] value;

public ByteArrayWrapper(byte[] value) {
this.value = value;
}

@Override
public boolean equals(Object o) {
if (o == null || o.getClass() != ByteArrayWrapper.class) {
return false;
}
ByteArrayWrapper other = (ByteArrayWrapper) o;
return Arrays.equals(this.value, other.value);
}

@Override
public int hashCode() {
return Arrays.hashCode(value);
}
}
}

0 comments on commit c0c48d9

Please sign in to comment.