Skip to content

Commit

Permalink
Refactor SafeSerializationUtils for better performance (#4973)
Browse files Browse the repository at this point in the history
Signed-off-by: shikharj05 <[email protected]>
  • Loading branch information
shikharj05 authored Dec 18, 2024
1 parent 79a3299 commit 2f870c7
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@
import java.net.SocketAddress;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import org.opensearch.security.auth.UserInjector;
Expand Down Expand Up @@ -57,7 +56,7 @@ public final class SafeSerializationUtils {
LdapAttribute.class
);

private static final List<Class<?>> SAFE_ASSIGNABLE_FROM_CLASSES = ImmutableList.of(
private static final Set<Class<?>> SAFE_ASSIGNABLE_FROM_CLASSES = ImmutableSet.of(
InetAddress.class,
Number.class,
Collection.class,
Expand All @@ -66,18 +65,28 @@ public final class SafeSerializationUtils {
);

private static final Set<String> SAFE_CLASS_NAMES = Collections.singleton("org.ldaptive.LdapAttribute$LdapAttributeValues");
static final Map<Class<?>, Boolean> safeClassCache = new ConcurrentHashMap<>();

static boolean isSafeClass(Class<?> cls) {
return cls.isArray()
|| SAFE_CLASSES.contains(cls)
|| SAFE_CLASS_NAMES.contains(cls.getName())
|| SAFE_ASSIGNABLE_FROM_CLASSES.stream().anyMatch(c -> c.isAssignableFrom(cls));
return safeClassCache.computeIfAbsent(cls, SafeSerializationUtils::computeIsSafeClass);
}

static boolean computeIsSafeClass(Class<?> cls) {
return cls.isArray() || SAFE_CLASSES.contains(cls) || SAFE_CLASS_NAMES.contains(cls.getName()) || isAssignableFromSafeClass(cls);
}

private static boolean isAssignableFromSafeClass(Class<?> cls) {
for (Class<?> safeClass : SAFE_ASSIGNABLE_FROM_CLASSES) {
if (safeClass.isAssignableFrom(cls)) {
return true;
}
}
return false;
}

static void prohibitUnsafeClasses(Class<?> clazz) throws IOException {
if (!isSafeClass(clazz)) {
throw new IOException("Unauthorized serialization attempt " + clazz.getName());
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
*/

package org.opensearch.security.support;

import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.regex.Pattern;

import org.junit.Test;

import org.opensearch.security.auth.UserInjector;
import org.opensearch.security.user.User;

import com.amazon.dlic.auth.ldap.LdapUser;
import org.ldaptive.AbstractLdapBean;
import org.ldaptive.LdapAttribute;
import org.ldaptive.LdapEntry;
import org.ldaptive.SearchEntry;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

public class SafeSerializationUtilsTest {

@Test
public void testSafeClasses() {
assertTrue(SafeSerializationUtils.isSafeClass(String.class));
assertTrue(SafeSerializationUtils.isSafeClass(InetSocketAddress.class));
assertTrue(SafeSerializationUtils.isSafeClass(Pattern.class));
assertTrue(SafeSerializationUtils.isSafeClass(User.class));
assertTrue(SafeSerializationUtils.isSafeClass(UserInjector.InjectedUser.class));
assertTrue(SafeSerializationUtils.isSafeClass(SourceFieldsContext.class));
assertTrue(SafeSerializationUtils.isSafeClass(LdapUser.class));
assertTrue(SafeSerializationUtils.isSafeClass(SearchEntry.class));
assertTrue(SafeSerializationUtils.isSafeClass(LdapEntry.class));
assertTrue(SafeSerializationUtils.isSafeClass(AbstractLdapBean.class));
assertTrue(SafeSerializationUtils.isSafeClass(LdapAttribute.class));
}

@Test
public void testSafeAssignableClasses() {
assertTrue(SafeSerializationUtils.isSafeClass(InetAddress.class));
assertTrue(SafeSerializationUtils.isSafeClass(Integer.class));
assertTrue(SafeSerializationUtils.isSafeClass(ArrayList.class));
assertTrue(SafeSerializationUtils.isSafeClass(HashMap.class));
assertTrue(SafeSerializationUtils.isSafeClass(Enum.class));
}

@Test
public void testArraysAreSafe() {
assertTrue(SafeSerializationUtils.isSafeClass(String[].class));
assertTrue(SafeSerializationUtils.isSafeClass(int[].class));
assertTrue(SafeSerializationUtils.isSafeClass(Object[].class));
}

@Test
public void testUnsafeClasses() {
assertFalse(SafeSerializationUtils.isSafeClass(SafeSerializationUtilsTest.class));
assertFalse(SafeSerializationUtils.isSafeClass(Runtime.class));
}

@Test
public void testProhibitUnsafeClasses() {
try {
SafeSerializationUtils.prohibitUnsafeClasses(String.class);
} catch (IOException e) {
fail("Should not throw exception for safe class");
}

try {
SafeSerializationUtils.prohibitUnsafeClasses(SafeSerializationUtilsTest.class);
fail("Should throw exception for unsafe class");
} catch (IOException e) {
assertEquals("Unauthorized serialization attempt " + SafeSerializationUtilsTest.class.getName(), e.getMessage());
}
}

@Test
public void testInheritance() {
class CustomArrayList extends ArrayList<String> {}
assertTrue(SafeSerializationUtils.isSafeClass(CustomArrayList.class));

class CustomMap extends HashMap<String, Integer> {}
assertTrue(SafeSerializationUtils.isSafeClass(CustomMap.class));
}

@Test
public void testCaching() {
// First call should compute the result
boolean result1 = SafeSerializationUtils.isSafeClass(String.class);
assertTrue(result1);

// Second call should use cached result
boolean result2 = SafeSerializationUtils.isSafeClass(String.class);
assertTrue(result2);

// Verify that the cache was used (size should be 1)
assertEquals(1, SafeSerializationUtils.safeClassCache.size());

// Third call for a different class
boolean result3 = SafeSerializationUtils.isSafeClass(Integer.class);
assertTrue(result3);
// Verify that the cache was updated
assertEquals(2, SafeSerializationUtils.safeClassCache.size());
}
}

0 comments on commit 2f870c7

Please sign in to comment.