diff --git a/pom.xml b/pom.xml index 5b9cf55500..add6b378ea 100644 --- a/pom.xml +++ b/pom.xml @@ -75,6 +75,12 @@ 2.11.0 + + redis.clients.authentication + redis-authx-core + 0.1.1-beta1 + + @@ -150,6 +156,13 @@ test + + redis.clients.authentication + redis-authx-entraid + 0.1.1-beta1 + test + + io.github.resilience4j diff --git a/src/main/java/redis/clients/jedis/Connection.java b/src/main/java/redis/clients/jedis/Connection.java index 2860866c6e..de473d0b8e 100644 --- a/src/main/java/redis/clients/jedis/Connection.java +++ b/src/main/java/redis/clients/jedis/Connection.java @@ -14,12 +14,14 @@ import java.util.List; import java.util.Map; import java.util.function.Supplier; +import java.util.concurrent.atomic.AtomicReference; import redis.clients.jedis.Protocol.Command; import redis.clients.jedis.Protocol.Keyword; import redis.clients.jedis.annots.Experimental; import redis.clients.jedis.args.ClientAttributeOption; import redis.clients.jedis.args.Rawable; +import redis.clients.jedis.authentication.AuthXManager; import redis.clients.jedis.commands.ProtocolCommand; import redis.clients.jedis.exceptions.JedisConnectionException; import redis.clients.jedis.exceptions.JedisDataException; @@ -44,6 +46,8 @@ public class Connection implements Closeable { private String strVal; protected String server; protected String version; + private AtomicReference currentCredentials = new AtomicReference<>(null); + private AuthXManager authXManager; public Connection() { this(Protocol.DEFAULT_HOST, Protocol.DEFAULT_PORT); @@ -63,6 +67,7 @@ public Connection(final HostAndPort hostAndPort, final JedisClientConfig clientC public Connection(final JedisSocketFactory socketFactory) { this.socketFactory = socketFactory; + this.authXManager = null; } public Connection(final JedisSocketFactory socketFactory, JedisClientConfig clientConfig) { @@ -93,8 +98,8 @@ public String toIdentityString() { SocketAddress remoteAddr = socket.getRemoteSocketAddress(); SocketAddress localAddr = socket.getLocalSocketAddress(); if (remoteAddr != null) { - strVal = String.format("%s{id: 0x%X, L:%s %c R:%s}", className, id, - localAddr, (broken ? '!' : '-'), remoteAddr); + strVal = String.format("%s{id: 0x%X, L:%s %c R:%s}", className, id, localAddr, + (broken ? '!' : '-'), remoteAddr); } else if (localAddr != null) { strVal = String.format("%s{id: 0x%X, L:%s}", className, id, localAddr); } else { @@ -438,8 +443,8 @@ private static boolean validateClientInfo(String info) { for (int i = 0; i < info.length(); i++) { char c = info.charAt(i); if (c < '!' || c > '~') { - throw new JedisValidationException("client info cannot contain spaces, " - + "newlines or special characters."); + throw new JedisValidationException( + "client info cannot contain spaces, " + "newlines or special characters."); } } return true; @@ -451,7 +456,13 @@ protected void initializeFromClientConfig(final JedisClientConfig config) { protocol = config.getRedisProtocol(); - final Supplier credentialsProvider = config.getCredentialsProvider(); + Supplier credentialsProvider = config.getCredentialsProvider(); + + authXManager = config.getAuthXManager(); + if (authXManager != null) { + credentialsProvider = authXManager; + } + if (credentialsProvider instanceof RedisCredentialsProvider) { final RedisCredentialsProvider redisCredentialsProvider = (RedisCredentialsProvider) credentialsProvider; try { @@ -469,7 +480,8 @@ protected void initializeFromClientConfig(final JedisClientConfig config) { String clientName = config.getClientName(); if (clientName != null && validateClientInfo(clientName)) { - fireAndForgetMsg.add(new CommandArguments(Command.CLIENT).add(Keyword.SETNAME).add(clientName)); + fireAndForgetMsg + .add(new CommandArguments(Command.CLIENT).add(Keyword.SETNAME).add(clientName)); } ClientSetInfoConfig setInfoConfig = config.getClientSetInfoConfig(); @@ -525,12 +537,13 @@ private void helloAndAuth(final RedisProtocol protocol, final RedisCredentials c if (protocol != null && credentials != null && credentials.getUser() != null) { byte[] rawPass = encodeToBytes(credentials.getPassword()); try { - helloResult = hello(encode(protocol.version()), Keyword.AUTH.getRaw(), encode(credentials.getUser()), rawPass); + helloResult = hello(encode(protocol.version()), Keyword.AUTH.getRaw(), + encode(credentials.getUser()), rawPass); } finally { Arrays.fill(rawPass, (byte) 0); // clear sensitive data } } else { - auth(credentials); + authenticate(credentials); helloResult = protocol == null ? null : hello(encode(protocol.version())); } if (helloResult != null) { @@ -542,9 +555,13 @@ private void helloAndAuth(final RedisProtocol protocol, final RedisCredentials c // handled in RedisCredentialsProvider.cleanUp() } - private void auth(RedisCredentials credentials) { + public void setCredentials(RedisCredentials credentials) { + currentCredentials.set(credentials); + } + + private String authenticate(RedisCredentials credentials) { if (credentials == null || credentials.getPassword() == null) { - return; + return null; } byte[] rawPass = encodeToBytes(credentials.getPassword()); try { @@ -556,7 +573,11 @@ private void auth(RedisCredentials credentials) { } finally { Arrays.fill(rawPass, (byte) 0); // clear sensitive data } - getStatusCodeReply(); + return getStatusCodeReply(); + } + + public String reAuthenticate() { + return authenticate(currentCredentials.getAndSet(null)); } protected Map hello(byte[]... args) { @@ -585,4 +606,12 @@ public boolean ping() { } return true; } + + protected boolean isTokenBasedAuthenticationEnabled() { + return authXManager != null; + } + + protected AuthXManager getAuthXManager() { + return authXManager; + } } diff --git a/src/main/java/redis/clients/jedis/ConnectionFactory.java b/src/main/java/redis/clients/jedis/ConnectionFactory.java index cc53df56f0..7440417152 100644 --- a/src/main/java/redis/clients/jedis/ConnectionFactory.java +++ b/src/main/java/redis/clients/jedis/ConnectionFactory.java @@ -6,7 +6,12 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.function.Supplier; + import redis.clients.jedis.annots.Experimental; +import redis.clients.jedis.authentication.AuthXManager; +import redis.clients.jedis.authentication.JedisAuthenticationException; +import redis.clients.jedis.authentication.AuthXEventListener; import redis.clients.jedis.csc.Cache; import redis.clients.jedis.csc.CacheConnection; import redis.clients.jedis.exceptions.JedisException; @@ -20,28 +25,52 @@ public class ConnectionFactory implements PooledObjectFactory { private final JedisSocketFactory jedisSocketFactory; private final JedisClientConfig clientConfig; - private Cache clientSideCache = null; + private final Cache clientSideCache; + private final Supplier objectMaker; + + private final AuthXEventListener authXEventListener; public ConnectionFactory(final HostAndPort hostAndPort) { - this.clientConfig = DefaultJedisClientConfig.builder().build(); - this.jedisSocketFactory = new DefaultJedisSocketFactory(hostAndPort); + this(hostAndPort, DefaultJedisClientConfig.builder().build(), null); } public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig) { - this.clientConfig = clientConfig; - this.jedisSocketFactory = new DefaultJedisSocketFactory(hostAndPort, this.clientConfig); + this(hostAndPort, clientConfig, null); } @Experimental - public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig, Cache csCache) { - this.clientConfig = clientConfig; - this.jedisSocketFactory = new DefaultJedisSocketFactory(hostAndPort, this.clientConfig); - this.clientSideCache = csCache; + public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig, + Cache csCache) { + this(new DefaultJedisSocketFactory(hostAndPort, clientConfig), clientConfig, csCache); } - public ConnectionFactory(final JedisSocketFactory jedisSocketFactory, final JedisClientConfig clientConfig) { - this.clientConfig = clientConfig; + public ConnectionFactory(final JedisSocketFactory jedisSocketFactory, + final JedisClientConfig clientConfig) { + this(jedisSocketFactory, clientConfig, null); + } + + private ConnectionFactory(final JedisSocketFactory jedisSocketFactory, + final JedisClientConfig clientConfig, Cache csCache) { + this.jedisSocketFactory = jedisSocketFactory; + this.clientSideCache = csCache; + this.clientConfig = clientConfig; + + AuthXManager authXManager = clientConfig.getAuthXManager(); + if (authXManager == null) { + this.objectMaker = connectionSupplier(); + this.authXEventListener = AuthXEventListener.NOOP_LISTENER; + } else { + Supplier supplier = connectionSupplier(); + this.objectMaker = () -> (Connection) authXManager.addConnection(supplier.get()); + this.authXEventListener = authXManager.getListener(); + authXManager.start(); + } + } + + private Supplier connectionSupplier() { + return clientSideCache == null ? () -> new Connection(jedisSocketFactory, clientConfig) + : () -> new CacheConnection(jedisSocketFactory, clientConfig, clientSideCache); } @Override @@ -64,8 +93,7 @@ public void destroyObject(PooledObject pooledConnection) throws Exce @Override public PooledObject makeObject() throws Exception { try { - Connection jedis = clientSideCache == null ? new Connection(jedisSocketFactory, clientConfig) - : new CacheConnection(jedisSocketFactory, clientConfig, clientSideCache); + Connection jedis = objectMaker.get(); return new DefaultPooledObject<>(jedis); } catch (JedisException je) { logger.debug("Error while makeObject", je); @@ -76,6 +104,8 @@ public PooledObject makeObject() throws Exception { @Override public void passivateObject(PooledObject pooledConnection) throws Exception { // TODO maybe should select db 0? Not sure right now. + Connection jedis = pooledConnection.getObject(); + reAuthenticate(jedis); } @Override @@ -83,10 +113,31 @@ public boolean validateObject(PooledObject pooledConnection) { final Connection jedis = pooledConnection.getObject(); try { // check HostAndPort ?? - return jedis.isConnected() && jedis.ping(); + if (!jedis.isConnected()) { + return false; + } + reAuthenticate(jedis); + return jedis.ping(); } catch (final Exception e) { logger.warn("Error while validating pooled Connection object.", e); return false; } } + + private void reAuthenticate(Connection jedis) throws Exception { + try { + String result = jedis.reAuthenticate(); + if (result != null && !result.equals("OK")) { + String msg = "Re-authentication failed with server response: " + result; + Exception failedAuth = new JedisAuthenticationException(msg); + logger.error(failedAuth.getMessage(), failedAuth); + authXEventListener.onConnectionAuthenticationError(failedAuth); + return; + } + } catch (Exception e) { + logger.error("Error while re-authenticating connection", e); + authXEventListener.onConnectionAuthenticationError(e); + throw e; + } + } } diff --git a/src/main/java/redis/clients/jedis/ConnectionPool.java b/src/main/java/redis/clients/jedis/ConnectionPool.java index 40d4861f98..2ae1401081 100644 --- a/src/main/java/redis/clients/jedis/ConnectionPool.java +++ b/src/main/java/redis/clients/jedis/ConnectionPool.java @@ -2,19 +2,27 @@ import org.apache.commons.pool2.PooledObjectFactory; import org.apache.commons.pool2.impl.GenericObjectPoolConfig; + import redis.clients.jedis.annots.Experimental; +import redis.clients.jedis.authentication.AuthXManager; import redis.clients.jedis.csc.Cache; +import redis.clients.jedis.exceptions.JedisException; import redis.clients.jedis.util.Pool; public class ConnectionPool extends Pool { + private AuthXManager authXManager; + public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig) { this(new ConnectionFactory(hostAndPort, clientConfig)); + attachAuthenticationListener(clientConfig.getAuthXManager()); } @Experimental - public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, Cache clientSideCache) { + public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, + Cache clientSideCache) { this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache)); + attachAuthenticationListener(clientConfig.getAuthXManager()); } public ConnectionPool(PooledObjectFactory factory) { @@ -24,12 +32,14 @@ public ConnectionPool(PooledObjectFactory factory) { public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, GenericObjectPoolConfig poolConfig) { this(new ConnectionFactory(hostAndPort, clientConfig), poolConfig); + attachAuthenticationListener(clientConfig.getAuthXManager()); } @Experimental - public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, Cache clientSideCache, - GenericObjectPoolConfig poolConfig) { + public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, + Cache clientSideCache, GenericObjectPoolConfig poolConfig) { this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache), poolConfig); + attachAuthenticationListener(clientConfig.getAuthXManager()); } public ConnectionPool(PooledObjectFactory factory, @@ -43,4 +53,29 @@ public Connection getResource() { conn.setHandlingPool(this); return conn; } + + @Override + public void close() { + try { + if (authXManager != null) { + authXManager.stop(); + } + } finally { + super.close(); + } + } + + private void attachAuthenticationListener(AuthXManager authXManager) { + this.authXManager = authXManager; + if (authXManager != null) { + authXManager.addPostAuthenticationHook(token -> { + try { + // this is to trigger validations on each connection via ConnectionFactory + evict(); + } catch (Exception e) { + throw new JedisException("Failed to evict connections from pool", e); + } + }); + } + } } diff --git a/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java b/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java index d5468f3a46..7d41d9d28a 100644 --- a/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java +++ b/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java @@ -5,6 +5,8 @@ import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSocketFactory; +import redis.clients.jedis.authentication.AuthXManager; + public final class DefaultJedisClientConfig implements JedisClientConfig { private final RedisProtocol redisProtocol; @@ -29,6 +31,8 @@ public final class DefaultJedisClientConfig implements JedisClientConfig { private final boolean readOnlyForRedisClusterReplicas; + private final AuthXManager authXManager; + private DefaultJedisClientConfig(DefaultJedisClientConfig.Builder builder) { this.redisProtocol = builder.redisProtocol; this.connectionTimeoutMillis = builder.connectionTimeoutMillis; @@ -45,6 +49,7 @@ private DefaultJedisClientConfig(DefaultJedisClientConfig.Builder builder) { this.hostAndPortMapper = builder.hostAndPortMapper; this.clientSetInfoConfig = builder.clientSetInfoConfig; this.readOnlyForRedisClusterReplicas = builder.readOnlyForRedisClusterReplicas; + this.authXManager = builder.authXManager; } @Override @@ -83,6 +88,11 @@ public Supplier getCredentialsProvider() { return credentialsProvider; } + @Override + public AuthXManager getAuthXManager() { + return authXManager; + } + @Override public int getDatabase() { return database; @@ -163,6 +173,8 @@ public static class Builder { private boolean readOnlyForRedisClusterReplicas = false; + private AuthXManager authXManager = null; + private Builder() { } @@ -279,6 +291,30 @@ public Builder readOnlyForRedisClusterReplicas() { this.readOnlyForRedisClusterReplicas = true; return this; } + + public Builder authXManager(AuthXManager authXManager) { + this.authXManager = authXManager; + return this; + } + + public Builder from(JedisClientConfig instance) { + this.redisProtocol = instance.getRedisProtocol(); + this.connectionTimeoutMillis = instance.getConnectionTimeoutMillis(); + this.socketTimeoutMillis = instance.getSocketTimeoutMillis(); + this.blockingSocketTimeoutMillis = instance.getBlockingSocketTimeoutMillis(); + this.credentialsProvider = instance.getCredentialsProvider(); + this.database = instance.getDatabase(); + this.clientName = instance.getClientName(); + this.ssl = instance.isSsl(); + this.sslSocketFactory = instance.getSslSocketFactory(); + this.sslParameters = instance.getSslParameters(); + this.hostnameVerifier = instance.getHostnameVerifier(); + this.hostAndPortMapper = instance.getHostAndPortMapper(); + this.clientSetInfoConfig = instance.getClientSetInfoConfig(); + this.readOnlyForRedisClusterReplicas = instance.isReadOnlyForRedisClusterReplicas(); + this.authXManager = instance.getAuthXManager(); + return this; + } } public static DefaultJedisClientConfig create(int connectionTimeoutMillis, int soTimeoutMillis, @@ -328,6 +364,9 @@ public static DefaultJedisClientConfig copyConfig(JedisClientConfig copy) { if (copy.isReadOnlyForRedisClusterReplicas()) { builder.readOnlyForRedisClusterReplicas(); } + + builder.authXManager(copy.getAuthXManager()); + return builder.build(); } } diff --git a/src/main/java/redis/clients/jedis/JedisClientConfig.java b/src/main/java/redis/clients/jedis/JedisClientConfig.java index 8bd18b5aaa..ce7fd82de4 100644 --- a/src/main/java/redis/clients/jedis/JedisClientConfig.java +++ b/src/main/java/redis/clients/jedis/JedisClientConfig.java @@ -5,6 +5,8 @@ import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSocketFactory; +import redis.clients.jedis.authentication.AuthXManager; + public interface JedisClientConfig { default RedisProtocol getRedisProtocol() { @@ -50,6 +52,10 @@ default Supplier getCredentialsProvider() { new DefaultRedisCredentials(getUser(), getPassword())); } + default AuthXManager getAuthXManager() { + return null; + } + default int getDatabase() { return Protocol.DEFAULT_DATABASE; } diff --git a/src/main/java/redis/clients/jedis/JedisClusterInfoCache.java b/src/main/java/redis/clients/jedis/JedisClusterInfoCache.java index ec63c5206a..9462527c0f 100644 --- a/src/main/java/redis/clients/jedis/JedisClusterInfoCache.java +++ b/src/main/java/redis/clients/jedis/JedisClusterInfoCache.java @@ -103,6 +103,9 @@ public JedisClusterInfoCache(final JedisClientConfig clientConfig, Cache clientS this.clientConfig = clientConfig; this.clientSideCache = clientSideCache; this.startNodes = startNodes; + if (clientConfig.getAuthXManager() != null) { + clientConfig.getAuthXManager().start(); + } if (topologyRefreshPeriod != null) { logger.info("Cluster topology refresh start, period: {}, startNodes: {}", topologyRefreshPeriod, startNodes); topologyRefreshExecutor = Executors.newSingleThreadScheduledExecutor(); diff --git a/src/main/java/redis/clients/jedis/JedisPubSubBase.java b/src/main/java/redis/clients/jedis/JedisPubSubBase.java index bf9d0a32c5..91fee36c58 100644 --- a/src/main/java/redis/clients/jedis/JedisPubSubBase.java +++ b/src/main/java/redis/clients/jedis/JedisPubSubBase.java @@ -4,6 +4,7 @@ import java.util.Arrays; import java.util.List; +import java.util.function.Consumer; import redis.clients.jedis.Protocol.Command; import redis.clients.jedis.exceptions.JedisException; @@ -12,7 +13,8 @@ public abstract class JedisPubSubBase { private int subscribedChannels = 0; - private volatile Connection client; + private final JedisSafeAuthenticator authenticator = new JedisSafeAuthenticator(); + private final Consumer pingResultHandler = this::processPingReply; public void onMessage(T channel, T message) { } @@ -36,12 +38,7 @@ public void onPong(T pattern) { } private void sendAndFlushCommand(Command command, T... args) { - if (client == null) { - throw new JedisException(getClass() + " is not connected to a Connection."); - } - CommandArguments cargs = new CommandArguments(command).addObjects(args); - client.sendCommand(cargs); - client.flush(); + authenticator.sendAndFlushCommand(command, args); } public final void unsubscribe() { @@ -53,13 +50,23 @@ public final void unsubscribe(T... channels) { } public final void subscribe(T... channels) { + checkConnectionSuitableForPubSub(); sendAndFlushCommand(Command.SUBSCRIBE, channels); } public final void psubscribe(T... patterns) { + checkConnectionSuitableForPubSub(); sendAndFlushCommand(Command.PSUBSCRIBE, patterns); } + private void checkConnectionSuitableForPubSub() { + if (authenticator.client.protocol != RedisProtocol.RESP3 + && authenticator.client.isTokenBasedAuthenticationEnabled()) { + throw new JedisException( + "Blocking pub/sub operations are not supported on token-based authentication enabled connections with RESP2 protocol!"); + } + } + public final void punsubscribe() { sendAndFlushCommand(Command.PUNSUBSCRIBE); } @@ -69,11 +76,23 @@ public final void punsubscribe(T... patterns) { } public final void ping() { - sendAndFlushCommand(Command.PING); + authenticator.commandSync.lock(); + try { + sendAndFlushCommand(Command.PING); + authenticator.resultHandler.add(pingResultHandler); + } finally { + authenticator.commandSync.unlock(); + } } public final void ping(T argument) { - sendAndFlushCommand(Command.PING, argument); + authenticator.commandSync.lock(); + try { + sendAndFlushCommand(Command.PING, argument); + authenticator.resultHandler.add(pingResultHandler); + } finally { + authenticator.commandSync.unlock(); + } } public final boolean isSubscribed() { @@ -85,34 +104,34 @@ public final int getSubscribedChannels() { } public final void proceed(Connection client, T... channels) { - this.client = client; - this.client.setTimeoutInfinite(); + authenticator.registerForAuthentication(client); + authenticator.client.setTimeoutInfinite(); try { subscribe(channels); process(); } finally { - this.client.rollbackTimeout(); + authenticator.client.rollbackTimeout(); } } public final void proceedWithPatterns(Connection client, T... patterns) { - this.client = client; - this.client.setTimeoutInfinite(); + authenticator.registerForAuthentication(client); + authenticator.client.setTimeoutInfinite(); try { psubscribe(patterns); process(); } finally { - this.client.rollbackTimeout(); + authenticator.client.rollbackTimeout(); } } protected abstract T encode(byte[] raw); -// private void process(Client client) { + // private void process(Client client) { private void process() { do { - Object reply = client.getUnflushedObject(); + Object reply = authenticator.client.getUnflushedObject(); if (reply instanceof List) { List listReply = (List) reply; @@ -166,18 +185,26 @@ private void process() { throw new JedisException("Unknown message type: " + firstObj); } } else if (reply instanceof byte[]) { - byte[] resp = (byte[]) reply; - if ("PONG".equals(SafeEncoder.encode(resp))) { - onPong(null); - } else { - onPong(encode(resp)); + Consumer resultHandler = authenticator.resultHandler.poll(); + if (resultHandler == null) { + throw new JedisException("Unexpected message : " + SafeEncoder.encode((byte[]) reply)); } + resultHandler.accept(reply); } else { throw new JedisException("Unknown message type: " + reply); } } while (!Thread.currentThread().isInterrupted() && isSubscribed()); -// /* Invalidate instance since this thread is no longer listening */ -// this.client = null; + // /* Invalidate instance since this thread is no longer listening */ + // this.client = null; + } + + private void processPingReply(Object reply) { + byte[] resp = (byte[]) reply; + if ("PONG".equals(SafeEncoder.encode(resp))) { + onPong(null); + } else { + onPong(encode(resp)); + } } } diff --git a/src/main/java/redis/clients/jedis/JedisSafeAuthenticator.java b/src/main/java/redis/clients/jedis/JedisSafeAuthenticator.java new file mode 100644 index 0000000000..9c7f95dba1 --- /dev/null +++ b/src/main/java/redis/clients/jedis/JedisSafeAuthenticator.java @@ -0,0 +1,104 @@ +package redis.clients.jedis; + +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import redis.clients.authentication.core.SimpleToken; +import redis.clients.authentication.core.Token; +import redis.clients.jedis.Protocol.Command; +import redis.clients.jedis.authentication.JedisAuthenticationException; +import redis.clients.jedis.exceptions.JedisException; +import redis.clients.jedis.util.SafeEncoder; + +class JedisSafeAuthenticator { + + private static final Token PLACEHOLDER_TOKEN = new SimpleToken(null, null, 0, 0, null); + private static final Logger logger = LoggerFactory.getLogger(JedisSafeAuthenticator.class); + + protected volatile Connection client; + protected final Consumer authResultHandler = this::processAuthReply; + protected final Consumer authenticationHandler = this::safeReAuthenticate; + + protected final AtomicReference pendingTokenRef = new AtomicReference(null); + protected final ReentrantLock commandSync = new ReentrantLock(); + protected final Queue> resultHandler = new ConcurrentLinkedQueue>(); + + protected void sendAndFlushCommand(Command command, Object... args) { + if (client == null) { + throw new JedisException(getClass() + " is not connected to a Connection."); + } + CommandArguments cargs = new CommandArguments(command).addObjects(args); + + Token newToken = pendingTokenRef.getAndSet(PLACEHOLDER_TOKEN); + + // lets send the command without locking !!IF!! we know that pendingTokenRef is null replaced with PLACEHOLDER_TOKEN and no re-auth will go into action + // !!ELSE!! we are locking since we already know a re-auth is still in progress in another thread and we need to wait for it to complete, we do nothing but wait on it! + if (newToken != null) { + commandSync.lock(); + } + try { + client.sendCommand(cargs); + client.flush(); + } finally { + Token newerToken = pendingTokenRef.getAndSet(null); + // lets check if a newer token received since the beginning of this sendAndFlushCommand call + if (newerToken != null && newerToken != PLACEHOLDER_TOKEN) { + safeReAuthenticate(newerToken); + } + if (newToken != null) { + commandSync.unlock(); + } + } + } + + protected void registerForAuthentication(Connection newClient) { + Connection oldClient = this.client; + if (oldClient == newClient) return; + if (oldClient != null && oldClient.getAuthXManager() != null) { + oldClient.getAuthXManager().removePostAuthenticationHook(authenticationHandler); + } + if (newClient != null && newClient.getAuthXManager() != null) { + newClient.getAuthXManager().addPostAuthenticationHook(authenticationHandler); + } + this.client = newClient; + } + + private void safeReAuthenticate(Token token) { + try { + byte[] rawPass = client.encodeToBytes(token.getValue().toCharArray()); + byte[] rawUser = client.encodeToBytes(token.getUser().toCharArray()); + + Token newToken = pendingTokenRef.getAndSet(token); + if (newToken == null) { + commandSync.lock(); + try { + sendAndFlushCommand(Command.AUTH, rawUser, rawPass); + resultHandler.add(this.authResultHandler); + } finally { + pendingTokenRef.set(null); + commandSync.unlock(); + } + } + } catch (Exception e) { + logger.error("Error while re-authenticating connection", e); + client.getAuthXManager().getListener().onConnectionAuthenticationError(e); + } + } + + protected void processAuthReply(Object reply) { + byte[] resp = (byte[]) reply; + String response = SafeEncoder.encode(resp); + if (!"OK".equals(response)) { + String msg = "Re-authentication failed with server response: " + response; + Exception failedAuth = new JedisAuthenticationException(msg); + logger.error(failedAuth.getMessage(), failedAuth); + client.getAuthXManager().getListener().onConnectionAuthenticationError(failedAuth); + } + } +} diff --git a/src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java b/src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java index 2b2ce944fe..9020693929 100644 --- a/src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java +++ b/src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java @@ -4,14 +4,16 @@ import java.util.Arrays; import java.util.List; +import java.util.function.Consumer; import redis.clients.jedis.Protocol.Command; import redis.clients.jedis.exceptions.JedisException; +import redis.clients.jedis.util.SafeEncoder; public abstract class JedisShardedPubSubBase { private int subscribedChannels = 0; - private volatile Connection client; + private final JedisSafeAuthenticator authenticator = new JedisSafeAuthenticator(); public void onSMessage(T channel, T message) { } @@ -23,12 +25,7 @@ public void onSUnsubscribe(T channel, int subscribedChannels) { } private void sendAndFlushCommand(Command command, T... args) { - if (client == null) { - throw new JedisException(getClass() + " is not connected to a Connection."); - } - CommandArguments cargs = new CommandArguments(command).addObjects(args); - client.sendCommand(cargs); - client.flush(); + authenticator.sendAndFlushCommand(command, args); } public final void sunsubscribe() { @@ -40,9 +37,18 @@ public final void sunsubscribe(T... channels) { } public final void ssubscribe(T... channels) { + checkConnectionSuitableForPubSub(); sendAndFlushCommand(Command.SSUBSCRIBE, channels); } + private void checkConnectionSuitableForPubSub() { + if (authenticator.client.protocol != RedisProtocol.RESP3 + && authenticator.client.isTokenBasedAuthenticationEnabled()) { + throw new JedisException( + "Blocking pub/sub operations are not supported on token-based authentication enabled connections with RESP2 protocol!"); + } + } + public final boolean isSubscribed() { return subscribedChannels > 0; } @@ -52,23 +58,22 @@ public final int getSubscribedChannels() { } public final void proceed(Connection client, T... channels) { - this.client = client; - this.client.setTimeoutInfinite(); + authenticator.registerForAuthentication(client); + authenticator.client.setTimeoutInfinite(); try { ssubscribe(channels); process(); } finally { - this.client.rollbackTimeout(); + authenticator.client.rollbackTimeout(); } } protected abstract T encode(byte[] raw); -// private void process(Client client) { private void process() { do { - Object reply = client.getUnflushedObject(); + Object reply = authenticator.client.getUnflushedObject(); if (reply instanceof List) { List listReply = (List) reply; @@ -96,6 +101,12 @@ private void process() { } else { throw new JedisException("Unknown message type: " + firstObj); } + } else if (reply instanceof byte[]) { + Consumer resultHandler = authenticator.resultHandler.poll(); + if (resultHandler == null) { + throw new JedisException("Unexpected message : " + SafeEncoder.encode((byte[]) reply)); + } + resultHandler.accept(reply); } else { throw new JedisException("Unknown message type: " + reply); } diff --git a/src/main/java/redis/clients/jedis/authentication/AuthXEventListener.java b/src/main/java/redis/clients/jedis/authentication/AuthXEventListener.java new file mode 100644 index 0000000000..4750404157 --- /dev/null +++ b/src/main/java/redis/clients/jedis/authentication/AuthXEventListener.java @@ -0,0 +1,21 @@ +package redis.clients.jedis.authentication; + +public interface AuthXEventListener { + + static AuthXEventListener NOOP_LISTENER = new AuthXEventListener() { + + @Override + public void onIdentityProviderError(Exception reason) { + } + + @Override + public void onConnectionAuthenticationError(Exception reason) { + } + + }; + + public void onIdentityProviderError(Exception reason); + + public void onConnectionAuthenticationError(Exception reason); + +} diff --git a/src/main/java/redis/clients/jedis/authentication/AuthXManager.java b/src/main/java/redis/clients/jedis/authentication/AuthXManager.java new file mode 100644 index 0000000000..eba5d8428f --- /dev/null +++ b/src/main/java/redis/clients/jedis/authentication/AuthXManager.java @@ -0,0 +1,128 @@ +package redis.clients.jedis.authentication; + +import java.lang.ref.WeakReference; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import redis.clients.authentication.core.Token; +import redis.clients.authentication.core.TokenAuthConfig; +import redis.clients.authentication.core.TokenListener; +import redis.clients.authentication.core.TokenManager; +import redis.clients.jedis.Connection; +import redis.clients.jedis.RedisCredentials; + +public final class AuthXManager implements Supplier { + + private static final Logger log = LoggerFactory.getLogger(AuthXManager.class); + + private TokenManager tokenManager; + private List> connections = Collections + .synchronizedList(new ArrayList<>()); + private Token currentToken; + private AuthXEventListener listener = AuthXEventListener.NOOP_LISTENER; + private List> postAuthenticateHooks = new ArrayList<>(); + private AtomicReference> uniqueStarterTask = new AtomicReference<>(); + + protected AuthXManager(TokenManager tokenManager) { + this.tokenManager = tokenManager; + } + + public AuthXManager(TokenAuthConfig tokenAuthConfig) { + this(new TokenManager(tokenAuthConfig.getIdentityProviderConfig().getProvider(), + tokenAuthConfig.getTokenManagerConfig())); + } + + public void start() { + Future safeStarter = safeStart(this::tokenManagerStart); + try { + safeStarter.get(); + } catch (InterruptedException | ExecutionException e) { + log.error("AuthXManager failed to start!", e); + throw new JedisAuthenticationException("AuthXManager failed to start!", + (e instanceof ExecutionException) ? e.getCause() : e); + } + } + + private Future safeStart(Runnable starter) { + if (uniqueStarterTask.compareAndSet(null, new CompletableFuture())) { + try { + starter.run(); + uniqueStarterTask.get().complete(null); + } catch (Exception e) { + uniqueStarterTask.get().completeExceptionally(e); + } + } + return uniqueStarterTask.get(); + } + + private void tokenManagerStart() { + tokenManager.start(new TokenListener() { + @Override + public void onTokenRenewed(Token token) { + currentToken = token; + authenticateConnections(token); + } + + @Override + public void onError(Exception reason) { + listener.onIdentityProviderError(reason); + } + }, true); + } + + public void authenticateConnections(Token token) { + RedisCredentials credentialsFromToken = new TokenCredentials(token); + for (WeakReference connectionRef : connections) { + Connection connection = connectionRef.get(); + if (connection != null) { + connection.setCredentials(credentialsFromToken); + } else { + connections.remove(connectionRef); + } + } + postAuthenticateHooks.forEach(hook -> hook.accept(token)); + } + + public Connection addConnection(Connection connection) { + connections.add(new WeakReference<>(connection)); + return connection; + } + + public void stop() { + tokenManager.stop(); + } + + public void setListener(AuthXEventListener listener) { + if (listener != null) { + this.listener = listener; + } + } + + public void addPostAuthenticationHook(Consumer postAuthenticateHook) { + postAuthenticateHooks.add(postAuthenticateHook); + } + + public void removePostAuthenticationHook(Consumer postAuthenticateHook) { + postAuthenticateHooks.remove(postAuthenticateHook); + } + + public AuthXEventListener getListener() { + return listener; + } + + @Override + public RedisCredentials get() { + return new TokenCredentials(this.currentToken); + } + +} \ No newline at end of file diff --git a/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java b/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java new file mode 100644 index 0000000000..c70ab98720 --- /dev/null +++ b/src/main/java/redis/clients/jedis/authentication/JedisAuthenticationException.java @@ -0,0 +1,14 @@ +package redis.clients.jedis.authentication; + +import redis.clients.jedis.exceptions.JedisException; + +public class JedisAuthenticationException extends JedisException { + + public JedisAuthenticationException(String message) { + super(message); + } + + public JedisAuthenticationException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java b/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java new file mode 100644 index 0000000000..143ee60b9d --- /dev/null +++ b/src/main/java/redis/clients/jedis/authentication/TokenCredentials.java @@ -0,0 +1,24 @@ +package redis.clients.jedis.authentication; + +import redis.clients.authentication.core.Token; +import redis.clients.jedis.RedisCredentials; + +class TokenCredentials implements RedisCredentials { + private final String user; + private final char[] password; + + public TokenCredentials(Token token) { + user = token.getUser(); + password = token.getValue().toCharArray(); + } + + @Override + public String getUser() { + return user; + } + + @Override + public char[] getPassword() { + return password; + } +} \ No newline at end of file diff --git a/src/test/java/redis/clients/jedis/authentication/EntraIDTestContext.java b/src/test/java/redis/clients/jedis/authentication/EntraIDTestContext.java new file mode 100644 index 0000000000..b58ee2fd21 --- /dev/null +++ b/src/test/java/redis/clients/jedis/authentication/EntraIDTestContext.java @@ -0,0 +1,123 @@ +package redis.clients.jedis.authentication; + +import java.io.ByteArrayInputStream; +import java.security.KeyFactory; +import java.security.PrivateKey; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.security.spec.PKCS8EncodedKeySpec; +import java.util.Arrays; +import java.util.Base64; +import java.util.HashSet; +import java.util.Set; + +public class EntraIDTestContext { + private static final String AZURE_CLIENT_ID = "AZURE_CLIENT_ID"; + private static final String AZURE_AUTHORITY = "AZURE_AUTHORITY"; + private static final String AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET"; + private static final String AZURE_PRIVATE_KEY = "AZURE_PRIVATE_KEY"; + private static final String AZURE_CERT = "AZURE_CERT"; + private static final String AZURE_REDIS_SCOPES = "AZURE_REDIS_SCOPES"; + private static final String AZURE_USER_ASSIGNED_MANAGED_ID = "AZURE_USER_ASSIGNED_MANAGED_ID"; + + private String clientId; + private String authority; + private String clientSecret; + private PrivateKey privateKey; + private X509Certificate cert; + private Set redisScopes; + private String userAssignedManagedIdentity; + + public static final EntraIDTestContext DEFAULT = new EntraIDTestContext(); + + private EntraIDTestContext() { + clientId = System.getenv(AZURE_CLIENT_ID); + authority = System.getenv(AZURE_AUTHORITY); + clientSecret = System.getenv(AZURE_CLIENT_SECRET); + userAssignedManagedIdentity = System.getenv(AZURE_USER_ASSIGNED_MANAGED_ID); + } + + public EntraIDTestContext(String clientId, String authority, String clientSecret, + PrivateKey privateKey, X509Certificate cert, Set redisScopes, + String userAssignedManagedIdentity) { + this.clientId = clientId; + this.authority = authority; + this.clientSecret = clientSecret; + this.privateKey = privateKey; + this.cert = cert; + this.redisScopes = redisScopes; + this.userAssignedManagedIdentity = userAssignedManagedIdentity; + } + + public String getClientId() { + return clientId; + } + + public String getAuthority() { + return authority; + } + + public String getClientSecret() { + return clientSecret; + } + + public PrivateKey getPrivateKey() { + if (privateKey == null) { + this.privateKey = getPrivateKey(System.getenv(AZURE_PRIVATE_KEY)); + } + return privateKey; + } + + public X509Certificate getCert() { + if (cert == null) { + this.cert = getCert(System.getenv(AZURE_CERT)); + } + return cert; + } + + public Set getRedisScopes() { + if (redisScopes == null) { + String redisScopesEnv = System.getenv(AZURE_REDIS_SCOPES); + this.redisScopes = new HashSet<>(Arrays.asList(redisScopesEnv.split(";"))); + } + return redisScopes; + } + + public String getUserAssignedManagedIdentity() { + return userAssignedManagedIdentity; + } + + private PrivateKey getPrivateKey(String privateKey) { + try { + // Decode the base64 encoded key into a byte array + byte[] decodedKey = Base64.getDecoder().decode(privateKey); + + // Generate the private key from the decoded byte array using PKCS8EncodedKeySpec + PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(decodedKey); + KeyFactory keyFactory = KeyFactory.getInstance("RSA"); // Use the correct algorithm (e.g., "RSA", "EC", "DSA") + PrivateKey key = keyFactory.generatePrivate(keySpec); + return key; + } catch (Exception e) { + e.printStackTrace(); + throw new RuntimeException(e); + } + } + + private X509Certificate getCert(String cert) { + try { + // Convert the Base64 encoded string into a byte array + byte[] encoded = java.util.Base64.getDecoder().decode(cert); + + // Create a CertificateFactory for X.509 certificates + CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509"); + + // Generate the certificate from the byte array + X509Certificate certificate = (X509Certificate) certificateFactory + .generateCertificate(new ByteArrayInputStream(encoded)); + return certificate; + } catch (Exception e) { + e.printStackTrace(); + throw new RuntimeException(e); + } + } +} diff --git a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java new file mode 100644 index 0000000000..55551331ed --- /dev/null +++ b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDIntegrationTests.java @@ -0,0 +1,381 @@ +package redis.clients.jedis.authentication; + +import static org.awaitility.Awaitility.await; +import static org.awaitility.Durations.FIVE_SECONDS; +import static org.awaitility.Durations.ONE_HUNDRED_MILLISECONDS; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.awaitility.Awaitility; +import org.awaitility.Durations; +import org.junit.BeforeClass; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; +import org.mockito.MockedConstruction; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import redis.clients.authentication.core.IdentityProvider; +import redis.clients.authentication.core.IdentityProviderConfig; +import redis.clients.authentication.core.SimpleToken; +import redis.clients.authentication.core.Token; +import redis.clients.authentication.core.TokenAuthConfig; +import redis.clients.authentication.entraid.EntraIDIdentityProvider; +import redis.clients.authentication.entraid.EntraIDIdentityProviderConfig; +import redis.clients.authentication.entraid.EntraIDTokenAuthConfigBuilder; +import redis.clients.authentication.entraid.ServicePrincipalInfo; +import redis.clients.jedis.Connection; +import redis.clients.jedis.DefaultJedisClientConfig; +import redis.clients.jedis.EndpointConfig; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.HostAndPorts; +import redis.clients.jedis.JedisPooled; +import redis.clients.jedis.exceptions.JedisAccessControlException; +import redis.clients.jedis.exceptions.JedisConnectionException; +import redis.clients.jedis.scenario.FaultInjectionClient; + +@FixMethodOrder(MethodSorters.NAME_ASCENDING) +public class RedisEntraIDIntegrationTests { + private static final Logger log = LoggerFactory.getLogger(RedisEntraIDIntegrationTests.class); + + private static EntraIDTestContext testCtx; + private static EndpointConfig endpointConfig; + private static HostAndPort hnp; + + private final FaultInjectionClient faultClient = new FaultInjectionClient(); + + @BeforeClass + public static void before() { + try { + testCtx = EntraIDTestContext.DEFAULT; + endpointConfig = HostAndPorts.getRedisEndpoint("standalone-entraid-acl"); + hnp = endpointConfig.getHostAndPort(); + } catch (IllegalArgumentException e) { + log.warn("Skipping test because no Redis endpoint is configured"); + org.junit.Assume.assumeTrue(false); + } + } + + @Test + public void testJedisConfig() { + AtomicInteger counter = new AtomicInteger(0); + try (MockedConstruction mockedConstructor = mockConstruction( + EntraIDIdentityProvider.class, (mock, context) -> { + ServicePrincipalInfo info = (ServicePrincipalInfo) context.arguments().get(0); + + assertEquals(testCtx.getClientId(), info.getClientId()); + assertEquals(testCtx.getAuthority(), info.getAuthority()); + assertEquals(testCtx.getClientSecret(), info.getSecret()); + assertEquals(testCtx.getRedisScopes(), context.arguments().get(1)); + assertNotNull(mock); + doAnswer(invocation -> { + counter.incrementAndGet(); + return new SimpleToken("default", "token1", System.currentTimeMillis() + 5 * 60 * 1000, + System.currentTimeMillis(), null); + }).when(mock).requestToken(); + })) { + + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .authority(testCtx.getAuthority()).clientId(testCtx.getClientId()) + .secret(testCtx.getClientSecret()).scopes(testCtx.getRedisScopes()).build(); + + DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + JedisPooled jedis = new JedisPooled(new HostAndPort("localhost", 6379), jedisConfig); + assertNotNull(jedis); + assertEquals(1, counter.get()); + + } + } + + // T.1.1 + // Verify authentication using Azure AD with service principals + @Test + public void withSecret_azureServicePrincipalIntegrationTest() { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret()) + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build(); + + DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + } + + // T.1.1 + // Verify authentication using Azure AD with service principals + @Test + public void withCertificate_azureServicePrincipalIntegrationTest() { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret()) + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()).build(); + + DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + } + + // T.2.2 + // Test that the Redis client is not blocked/interrupted during token renewal. + @Test + public void renewalDuringOperationsTest() throws InterruptedException, ExecutionException { + // set the stage with consecutive get/set operations with unique keys which keeps running with a jedispooled instace, + // configure token manager to renew token approximately approximately every 10ms + // wait till token was renewed at least 10 times after initial token acquisition + // Additional note: Assumptions made on the time taken for token renewal and operations are based on the current implementation and may vary in future + // Assumptions: + // - TTL of token is 2 hour + // - expirationRefreshRatio is 0.000001F + // - renewal delay is 7 ms each time a token is acquired + // - each auth command takes 40 ms in total to complete(considering the cloud test environments) + // - each auth command would need to wait for an ongoing customer operation(GET/SET/DEL) to complete, which would take another 40 ms + // - each renewal happens in 40+40+7 = 87 ms + // - total number of renewals would take 87 * 10 = 870 ms + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret()) + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()) + .expirationRefreshRatio(0.000001F).build(); + + AuthXManager authXManager = new AuthXManager(tokenAuthConfig); + Consumer hook = mock(Consumer.class); + authXManager.addPostAuthenticationHook(hook); + + DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder() + .authXManager(authXManager).build(); + + ExecutorService jedisExecutors = Executors.newFixedThreadPool(5); + AtomicBoolean completed = new AtomicBoolean(false); + + ExecutorService runner = Executors.newSingleThreadExecutor(); + runner.submit(() -> { + + try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) { + List> futures = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + Future future = jedisExecutors.submit(() -> { + while (!completed.get()) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + }); + futures.add(future); + } + for (Future task : futures) { + try { + task.get(); + } catch (InterruptedException | ExecutionException e) { + e.printStackTrace(); + } + } + } + }); + + await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(FIVE_SECONDS).untilAsserted(() -> { + verify(hook, atLeast(10)).accept(any()); + }); + + completed.set(true); + runner.shutdown(); + jedisExecutors.shutdown(); + } + + // T.3.2 + // Verify that all existing connections can be re-authenticated when a new token is received. + @Test + public void allConnectionsReauthTest() throws InterruptedException, ExecutionException { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret()) + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()) + .expirationRefreshRatio(0.000001F).build(); + + AuthXManager authXManager = new AuthXManager(tokenAuthConfig); + authXManager = spy(authXManager); + + List connections = new ArrayList<>(); + + doAnswer(invocation -> { + Connection connection = spy((Connection) invocation.getArgument(0)); + invocation.getArguments()[0] = connection; + connections.add(connection); + Object result = invocation.callRealMethod(); + return result; + }).when(authXManager).addConnection(any(Connection.class)); + + DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder() + .authXManager(authXManager).build(); + + long startTime = System.currentTimeMillis(); + List> futures = new ArrayList<>(); + ExecutorService executor = Executors.newFixedThreadPool(5); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) { + for (int i = 0; i < 5; i++) { + Future future = executor.submit(() -> { + for (; System.currentTimeMillis() - startTime < 2000;) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + }); + futures.add(future); + } + for (Future task : futures) { + task.get(); + } + + connections.forEach(conn -> { + verify(conn, atLeast(1)).reAuthenticate(); + }); + executor.shutdown(); + } + } + + // T.3.3 + // Verify behavior when attempting to authenticate a single connection with an expired token. + @Test + public void connectionAuthWithExpiredTokenTest() { + IdentityProvider idp = new EntraIDIdentityProviderConfig( + new ServicePrincipalInfo(testCtx.getClientId(), testCtx.getClientSecret(), + testCtx.getAuthority()), + testCtx.getRedisScopes(), 1000).getProvider(); + + IdentityProvider mockIdentityProvider = mock(IdentityProvider.class); + AtomicReference token = new AtomicReference<>(); + doAnswer(invocation -> { + if (token.get() == null) { + token.set(idp.requestToken()); + } + return token.get(); + }).when(mockIdentityProvider).requestToken(); + IdentityProviderConfig idpConfig = mock(IdentityProviderConfig.class); + when(idpConfig.getProvider()).thenReturn(mockIdentityProvider); + + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .identityProviderConfig(idpConfig).expirationRefreshRatio(0.000001F).build(); + AuthXManager authXManager = new AuthXManager(tokenAuthConfig); + DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder() + .authXManager(authXManager).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) { + for (int i = 0; i < 50; i++) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + + token.set(new SimpleToken(idp.requestToken().getUser(), "token1", + System.currentTimeMillis() - 1, System.currentTimeMillis(), null)); + + JedisAccessControlException aclException = assertThrows(JedisAccessControlException.class, + () -> { + for (int i = 0; i < 50; i++) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + }); + + assertEquals("WRONGPASS invalid username-password pair", aclException.getMessage()); + } + } + + // T.3.4 + // Verify handling of reconnection and re-authentication after a network partition. (use cached token) + // @Test + public void networkPartitionEvictionTest() { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret()) + .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes()) + .expirationRefreshRatio(0.5F).build(); + AuthXManager authXManager = new AuthXManager(tokenAuthConfig); + DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder() + .authXManager(authXManager).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) { + for (int i = 0; i < 5; i++) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + + triggerNetworkFailure(); + + JedisConnectionException aclException = assertThrows(JedisConnectionException.class, () -> { + for (int i = 0; i < 50; i++) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + }); + + assertEquals("Unexpected end of stream.", aclException.getMessage()); + Awaitility.await().pollDelay(Durations.ONE_HUNDRED_MILLISECONDS).atMost(Durations.TWO_SECONDS) + .until(() -> { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + return true; + }); + } + } + + private void triggerNetworkFailure() { + HashMap params = new HashMap<>(); + params.put("bdb_id", endpointConfig.getBdbId()); + + FaultInjectionClient.TriggerActionResponse actionResponse = null; + String action = "network_failure"; + try { + log.info("Triggering {}", action); + actionResponse = faultClient.triggerAction(action, params); + } catch (IOException e) { + fail("Fault Injection Server error:" + e.getMessage()); + } + log.info("Action id: {}", actionResponse.getActionId()); + } +} diff --git a/src/test/java/redis/clients/jedis/authentication/RedisEntraIDManagedIdentityIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDManagedIdentityIntegrationTests.java new file mode 100644 index 0000000000..7e305ab766 --- /dev/null +++ b/src/test/java/redis/clients/jedis/authentication/RedisEntraIDManagedIdentityIntegrationTests.java @@ -0,0 +1,81 @@ +package redis.clients.jedis.authentication; + +import static org.junit.Assert.assertEquals; + +import java.util.Collections; +import java.util.Set; +import java.util.UUID; + +import org.junit.BeforeClass; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import redis.clients.authentication.core.TokenAuthConfig; +import redis.clients.authentication.entraid.EntraIDTokenAuthConfigBuilder; +import redis.clients.authentication.entraid.ManagedIdentityInfo.UserManagedIdentityType; +import redis.clients.jedis.DefaultJedisClientConfig; +import redis.clients.jedis.EndpointConfig; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.HostAndPorts; +import redis.clients.jedis.JedisPooled; + +public class RedisEntraIDManagedIdentityIntegrationTests { + private static final Logger log = LoggerFactory.getLogger(RedisEntraIDIntegrationTests.class); + + private static EntraIDTestContext testCtx; + private static EndpointConfig endpointConfig; + private static HostAndPort hnp; + private static Set managedIdentityAudience = Collections + .singleton("https://redis.azure.com"); + + @BeforeClass + public static void before() { + try { + testCtx = EntraIDTestContext.DEFAULT; + endpointConfig = HostAndPorts.getRedisEndpoint("standalone-entraid-acl"); + hnp = endpointConfig.getHostAndPort(); + } catch (IllegalArgumentException e) { + log.warn("Skipping test because no Redis endpoint is configured"); + org.junit.Assume.assumeTrue(false); + } + } + + // T.1.1 + // Verify authentication using Azure AD with managed identities + @Test + public void withUserAssignedId_azureManagedIdentityIntegrationTest() { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .userAssignedManagedIdentity(UserManagedIdentityType.OBJECT_ID, + testCtx.getUserAssignedManagedIdentity()) + .scopes(managedIdentityAudience).build(); + + DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + } + + // T.1.1 + // Verify authentication using Azure AD with managed identities + @Test + public void withSystemAssignedId_azureManagedIdentityIntegrationTest() { + TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder() + .systemAssignedManagedIdentity().scopes(managedIdentityAudience).build(); + + DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + try (JedisPooled jedis = new JedisPooled(hnp, jedisConfig)) { + String key = UUID.randomUUID().toString(); + jedis.set(key, "value"); + assertEquals("value", jedis.get(key)); + jedis.del(key); + } + } +} diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java new file mode 100644 index 0000000000..cd7e8eb6f4 --- /dev/null +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationClusterIntegrationTests.java @@ -0,0 +1,131 @@ +package redis.clients.jedis.authentication; + +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.awaitility.Awaitility.await; +import static org.awaitility.Durations.*; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +import org.junit.Test; + +import redis.clients.authentication.core.IdentityProvider; +import redis.clients.authentication.core.IdentityProviderConfig; +import redis.clients.authentication.core.SimpleToken; +import redis.clients.authentication.core.Token; +import redis.clients.authentication.entraid.EntraIDTokenAuthConfigBuilder; +import redis.clients.jedis.Connection; +import redis.clients.jedis.ConnectionPoolConfig; +import redis.clients.jedis.DefaultJedisClientConfig; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.HostAndPorts; +import redis.clients.jedis.JedisClientConfig; +import redis.clients.jedis.JedisCluster; +import redis.clients.jedis.JedisClusterTestBase; + +public class TokenBasedAuthenticationClusterIntegrationTests extends JedisClusterTestBase { + + @Test + public void testClusterInitWithAuthXManager() { + IdentityProviderConfig idpConfig = new IdentityProviderConfig() { + @Override + public IdentityProvider getProvider() { + return new IdentityProvider() { + @Override + public Token requestToken() { + return new SimpleToken("default", "cluster", + System.currentTimeMillis() + 5 * 1000, System.currentTimeMillis(), + null); + } + }; + } + }; + AuthXManager manager = new AuthXManager(EntraIDTokenAuthConfigBuilder.builder() + .lowerRefreshBoundMillis(1000).identityProviderConfig(idpConfig).build()); + + HostAndPort hp = HostAndPorts.getClusterServers().get(0); + int defaultDirections = 5; + JedisClientConfig config = DefaultJedisClientConfig.builder().authXManager(manager).build(); + + ConnectionPoolConfig DEFAULT_POOL_CONFIG = new ConnectionPoolConfig(); + try (JedisCluster jc = new JedisCluster(hp, config, defaultDirections, + DEFAULT_POOL_CONFIG)) { + + assertEquals("OK", jc.set("foo", "bar")); + assertEquals("bar", jc.get("foo")); + assertEquals(1, jc.del("foo")); + } + } + + @Test + public void testClusterWithReAuth() throws InterruptedException, ExecutionException { + IdentityProviderConfig idpConfig = new IdentityProviderConfig() { + @Override + public IdentityProvider getProvider() { + return new IdentityProvider() { + @Override + public Token requestToken() { + return new SimpleToken("default", "cluster", + System.currentTimeMillis() + 5 * 1000, System.currentTimeMillis(), + null); + } + }; + } + }; + AuthXManager authXManager = new AuthXManager(EntraIDTokenAuthConfigBuilder.builder() + .lowerRefreshBoundMillis(4600).identityProviderConfig(idpConfig).build()); + + authXManager = spy(authXManager); + + List connections = new ArrayList<>(); + doAnswer(invocation -> { + Connection connection = spy((Connection) invocation.getArgument(0)); + invocation.getArguments()[0] = connection; + connections.add(connection); + Object result = invocation.callRealMethod(); + return result; + }).when(authXManager).addConnection(any(Connection.class)); + + HostAndPort hp = HostAndPorts.getClusterServers().get(0); + JedisClientConfig config = DefaultJedisClientConfig.builder().authXManager(authXManager) + .build(); + + ExecutorService executorService = Executors.newFixedThreadPool(2); + CountDownLatch latch = new CountDownLatch(1); + try (JedisCluster jc = new JedisCluster(Collections.singleton(hp), config)) { + Runnable task = () -> { + while (latch.getCount() > 0) { + assertEquals("OK", jc.set("foo", "bar")); + } + }; + Future task1 = executorService.submit(task); + Future task2 = executorService.submit(task); + + await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(ONE_SECOND) + .until(connections::size, greaterThanOrEqualTo(2)); + + connections.forEach(conn -> { + await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(ONE_SECOND) + .untilAsserted(() -> verify(conn, atLeast(2)).reAuthenticate()); + }); + latch.countDown(); + task1.get(); + task2.get(); + } finally { + latch.countDown(); + executorService.shutdown(); + } + } +} diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java new file mode 100644 index 0000000000..9060f80719 --- /dev/null +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationIntegrationTests.java @@ -0,0 +1,252 @@ +package redis.clients.jedis.authentication; + +import static org.mockito.Mockito.when; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.awaitility.Awaitility.await; +import static org.awaitility.Durations.ONE_HUNDRED_MILLISECONDS; +import static org.awaitility.Durations.ONE_SECOND; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.MatcherAssert.assertThat; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.BeforeClass; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import redis.clients.authentication.core.IdentityProvider; +import redis.clients.authentication.core.IdentityProviderConfig; +import redis.clients.authentication.core.SimpleToken; +import redis.clients.authentication.core.TokenAuthConfig; +import redis.clients.jedis.CommandArguments; +import redis.clients.jedis.Connection; +/* */ +import redis.clients.jedis.DefaultJedisClientConfig; +import redis.clients.jedis.EndpointConfig; +import redis.clients.jedis.HostAndPorts; +import redis.clients.jedis.JedisClientConfig; +import redis.clients.jedis.JedisPooled; +import redis.clients.jedis.JedisPubSub; +import redis.clients.jedis.RedisProtocol; +import redis.clients.jedis.Protocol.Command; +import redis.clients.jedis.exceptions.JedisException; + +public class TokenBasedAuthenticationIntegrationTests { + private static final Logger log = LoggerFactory + .getLogger(TokenBasedAuthenticationIntegrationTests.class); + + private static EndpointConfig endpointConfig; + + @BeforeClass + public static void before() { + try { + endpointConfig = HostAndPorts.getRedisEndpoint("standalone0"); + } catch (IllegalArgumentException e) { + log.warn("Skipping test because no Redis endpoint is configured"); + org.junit.Assume.assumeTrue(false); + } + } + + @Test + public void testJedisPooledForInitialAuth() { + String user = "default"; + String password = endpointConfig.getPassword(); + + IdentityProvider idProvider = mock(IdentityProvider.class); + when(idProvider.requestToken()).thenReturn(new SimpleToken(user, password, + System.currentTimeMillis() + 100000, System.currentTimeMillis(), null)); + + IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class); + when(idProviderConfig.getProvider()).thenReturn(idProvider); + + TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder() + .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F) + .lowerRefreshBoundMillis(10000).tokenRequestExecTimeoutInMs(1000).build(); + + JedisClientConfig clientConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) { + jedis.get("key1"); + } + } + + @Test + public void testJedisPooledReauth() { + String user = "default"; + String password = endpointConfig.getPassword(); + + IdentityProvider idProvider = mock(IdentityProvider.class); + when(idProvider.requestToken()).thenAnswer(invocation -> new SimpleToken(user, password, + System.currentTimeMillis() + 5000, System.currentTimeMillis(), null)); + + IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class); + when(idProviderConfig.getProvider()).thenReturn(idProvider); + + TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder() + .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F) + .lowerRefreshBoundMillis(4800).tokenRequestExecTimeoutInMs(1000).build(); + + AuthXManager authXManager = new AuthXManager(tokenAuthConfig); + authXManager = spy(authXManager); + List connections = new ArrayList<>(); + doAnswer(invocation -> { + Connection connection = spy((Connection) invocation.getArgument(0)); + invocation.getArguments()[0] = connection; + connections.add(connection); + Object result = invocation.callRealMethod(); + return result; + }).when(authXManager).addConnection(any(Connection.class)); + + JedisClientConfig clientConfig = DefaultJedisClientConfig.builder().authXManager(authXManager) + .build(); + + try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) { + AtomicBoolean stop = new AtomicBoolean(false); + ExecutorService executor = Executors.newSingleThreadExecutor(); + executor.submit(() -> { + while (!stop.get()) { + jedis.get("key1"); + } + }); + + for (Connection connection : connections) { + await().pollDelay(ONE_HUNDRED_MILLISECONDS).atMost(ONE_SECOND).untilAsserted(() -> { + verify(connection, atLeast(3)).reAuthenticate(); + }); + } + stop.set(true); + executor.shutdown(); + } + } + + @Test + public void testPubSubForInitialAuth() throws InterruptedException { + String user = "default"; + String password = endpointConfig.getPassword(); + + IdentityProvider idProvider = mock(IdentityProvider.class); + when(idProvider.requestToken()).thenReturn(new SimpleToken(user, password, + System.currentTimeMillis() + 100000, System.currentTimeMillis(), null)); + + IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class); + when(idProviderConfig.getProvider()).thenReturn(idProvider); + + TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder() + .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F) + .lowerRefreshBoundMillis(10000).tokenRequestExecTimeoutInMs(1000).build(); + + JedisClientConfig clientConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).protocol(RedisProtocol.RESP3).build(); + + JedisPubSub pubSub = new JedisPubSub() { + public void onSubscribe(String channel, int subscribedChannels) { + this.unsubscribe(); + } + }; + + try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) { + jedis.subscribe(pubSub, "channel1"); + } + } + + @Test + public void testJedisPubSubReauth() { + String user = "default"; + String password = endpointConfig.getPassword(); + + IdentityProvider idProvider = mock(IdentityProvider.class); + when(idProvider.requestToken()).thenAnswer(invocation -> new SimpleToken(user, password, + System.currentTimeMillis() + 5000, System.currentTimeMillis(), null)); + + IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class); + when(idProviderConfig.getProvider()).thenReturn(idProvider); + + TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder() + .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F) + .lowerRefreshBoundMillis(4800).tokenRequestExecTimeoutInMs(1000).build(); + + AuthXManager authXManager = new AuthXManager(tokenAuthConfig); + authXManager = spy(authXManager); + List connections = new ArrayList<>(); + doAnswer(invocation -> { + Connection connection = spy((Connection) invocation.getArgument(0)); + invocation.getArguments()[0] = connection; + connections.add(connection); + Object result = invocation.callRealMethod(); + return result; + }).when(authXManager).addConnection(any(Connection.class)); + + JedisClientConfig clientConfig = DefaultJedisClientConfig.builder().authXManager(authXManager) + .protocol(RedisProtocol.RESP3).build(); + + JedisPubSub pubSub = new JedisPubSub() { + }; + try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) { + ExecutorService executor = Executors.newSingleThreadExecutor(); + executor.submit(() -> { + jedis.subscribe(pubSub, "channel1"); + }); + + await().pollDelay(ONE_HUNDRED_MILLISECONDS).atMost(ONE_SECOND) + .until(pubSub::getSubscribedChannels, greaterThan(0)); + + assertEquals(1, connections.size()); + for (Connection connection : connections) { + await().pollDelay(ONE_HUNDRED_MILLISECONDS).atMost(ONE_SECOND).untilAsserted(() -> { + ArgumentCaptor captor = ArgumentCaptor.forClass(CommandArguments.class); + + verify(connection, atLeast(3)).sendCommand(captor.capture()); + assertThat(captor.getAllValues().stream() + .filter((item) -> item.getCommand() == Command.AUTH).count(), + greaterThan(3L)); + + }); + } + pubSub.unsubscribe(); + executor.shutdown(); + } + } + + @Test + public void testJedisPubSubWithResp2() { + String user = "default"; + String password = endpointConfig.getPassword(); + + IdentityProvider idProvider = mock(IdentityProvider.class); + when(idProvider.requestToken()).thenReturn(new SimpleToken(user, password, + System.currentTimeMillis() + 100000, System.currentTimeMillis(), null)); + + IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class); + when(idProviderConfig.getProvider()).thenReturn(idProvider); + + TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder() + .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F) + .lowerRefreshBoundMillis(10000).tokenRequestExecTimeoutInMs(1000).build(); + + JedisClientConfig clientConfig = DefaultJedisClientConfig.builder() + .authXManager(new AuthXManager(tokenAuthConfig)).build(); + + try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) { + JedisPubSub pubSub = new JedisPubSub() { + }; + JedisException e = assertThrows(JedisException.class, + () -> jedis.subscribe(pubSub, "channel1")); + assertEquals( + "Blocking pub/sub operations are not supported on token-based authentication enabled connections with RESP2 protocol!", + e.getMessage()); + } + } +} diff --git a/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java new file mode 100644 index 0000000000..802dda2b86 --- /dev/null +++ b/src/test/java/redis/clients/jedis/authentication/TokenBasedAuthenticationUnitTests.java @@ -0,0 +1,342 @@ +package redis.clients.jedis.authentication; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; +import static org.awaitility.Awaitility.await; +import static org.awaitility.Durations.*; +import static org.hamcrest.CoreMatchers.either; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; + +import org.hamcrest.Matchers; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.MockedConstruction; + +import redis.clients.authentication.core.IdentityProvider; +import redis.clients.authentication.core.IdentityProviderConfig; +import redis.clients.authentication.core.SimpleToken; +import redis.clients.authentication.core.Token; +import redis.clients.authentication.core.TokenAuthConfig; +import redis.clients.authentication.core.TokenListener; +import redis.clients.authentication.core.TokenManager; +import redis.clients.authentication.core.TokenManagerConfig; +import redis.clients.authentication.core.TokenManagerConfig.RetryPolicy; +import redis.clients.jedis.ConnectionPool; +import redis.clients.jedis.EndpointConfig; +import redis.clients.jedis.HostAndPort; + +public class TokenBasedAuthenticationUnitTests { + + private HostAndPort hnp = new HostAndPort("localhost", 6379); + private EndpointConfig endpoint = new EndpointConfig(hnp, null, null, false); + + @Test + public void testJedisAuthXManagerInstance() { + TokenManagerConfig tokenManagerConfig = mock(TokenManagerConfig.class); + IdentityProviderConfig identityProviderConfig = mock(IdentityProviderConfig.class); + IdentityProvider identityProvider = mock(IdentityProvider.class); + + when(identityProviderConfig.getProvider()).thenReturn(identityProvider); + + try (MockedConstruction mockedConstructor = mockConstruction(TokenManager.class, + (mock, context) -> { + assertEquals(identityProvider, context.arguments().get(0)); + assertEquals(tokenManagerConfig, context.arguments().get(1)); + })) { + + new AuthXManager(new TokenAuthConfig(tokenManagerConfig, identityProviderConfig)); + } + } + + @Test + public void withExpirationRefreshRatio_testJedisAuthXManagerTriggersEvict() throws Exception { + + IdentityProvider idProvider = mock(IdentityProvider.class); + when(idProvider.requestToken()) + .thenReturn(new SimpleToken("default", "password", System.currentTimeMillis() + 1000, + System.currentTimeMillis(), Collections.singletonMap("oid", "default"))); + + TokenManager tokenManager = new TokenManager(idProvider, + new TokenManagerConfig(0.4F, 100, 1000, new RetryPolicy(1, 1))); + AuthXManager jedisAuthXManager = new AuthXManager(tokenManager); + + AtomicInteger numberOfEvictions = new AtomicInteger(0); + ConnectionPool pool = new ConnectionPool(hnp, + endpoint.getClientConfigBuilder().authXManager(jedisAuthXManager).build()) { + @Override + public void evict() throws Exception { + numberOfEvictions.incrementAndGet(); + super.evict(); + } + }; + + await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(FIVE_HUNDRED_MILLISECONDS) + .until(numberOfEvictions::get, Matchers.greaterThanOrEqualTo(1)); + } + + public void withLowerRefreshBounds_testJedisAuthXManagerTriggersEvict() throws Exception { + + IdentityProvider idProvider = mock(IdentityProvider.class); + when(idProvider.requestToken()) + .thenReturn(new SimpleToken("default", "password", System.currentTimeMillis() + 1000, + System.currentTimeMillis(), Collections.singletonMap("oid", "default"))); + + TokenManager tokenManager = new TokenManager(idProvider, + new TokenManagerConfig(0.9F, 600, 1000, new RetryPolicy(1, 1))); + AuthXManager jedisAuthXManager = new AuthXManager(tokenManager); + + AtomicInteger numberOfEvictions = new AtomicInteger(0); + ConnectionPool pool = new ConnectionPool(endpoint.getHostAndPort(), + endpoint.getClientConfigBuilder().authXManager(jedisAuthXManager).build()) { + @Override + public void evict() throws Exception { + numberOfEvictions.incrementAndGet(); + super.evict(); + } + }; + + await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(FIVE_HUNDRED_MILLISECONDS) + .until(numberOfEvictions::get, Matchers.greaterThanOrEqualTo(1)); + } + + public static class TokenManagerConfigWrapper extends TokenManagerConfig { + int lower; + float ratio; + + public TokenManagerConfigWrapper() { + super(0, 0, 0, null); + } + + @Override + public int getLowerRefreshBoundMillis() { + return lower; + } + + @Override + public float getExpirationRefreshRatio() { + return ratio; + } + + @Override + public RetryPolicy getRetryPolicy() { + return new RetryPolicy(1, 1); + } + } + + @Test + public void testCalculateRenewalDelay() { + long delay = 0; + long duration = 0; + long issueDate; + long expireDate; + + TokenManagerConfigWrapper config = new TokenManagerConfigWrapper(); + TokenManager manager = new TokenManager(() -> null, config); + + duration = 5000; + config.lower = 2000; + config.ratio = 0.5F; + issueDate = System.currentTimeMillis(); + expireDate = issueDate + duration; + + delay = manager.calculateRenewalDelay(expireDate, issueDate); + + assertThat(delay, + lessThanOrEqualTo(Math.min(duration - config.lower, (long) (duration * config.ratio)))); + + duration = 10000; + config.lower = 8000; + config.ratio = 0.2F; + issueDate = System.currentTimeMillis(); + expireDate = issueDate + duration; + + delay = manager.calculateRenewalDelay(expireDate, issueDate); + + assertThat(delay, + lessThanOrEqualTo(Math.min(duration - config.lower, (long) (duration * config.ratio)))); + + duration = 10000; + config.lower = 10000; + config.ratio = 0.2F; + issueDate = System.currentTimeMillis(); + expireDate = issueDate + duration; + + delay = manager.calculateRenewalDelay(expireDate, issueDate); + + assertEquals(0, delay); + + duration = 0; + config.lower = 5000; + config.ratio = 0.2F; + issueDate = System.currentTimeMillis(); + expireDate = issueDate + duration; + + delay = manager.calculateRenewalDelay(expireDate, issueDate); + + assertEquals(0, delay); + + duration = 10000; + config.lower = 1000; + config.ratio = 0.00001F; + issueDate = System.currentTimeMillis(); + expireDate = issueDate + duration; + + delay = manager.calculateRenewalDelay(expireDate, issueDate); + + assertEquals(0, delay); + + duration = 10000; + config.lower = 1000; + config.ratio = 0.0001F; + issueDate = System.currentTimeMillis(); + expireDate = issueDate + duration; + + delay = manager.calculateRenewalDelay(expireDate, issueDate); + + assertThat(delay, either(is(0L)).or(is(1L))); + } + + @Test + public void testAuthXManagerReceivesNewToken() + throws InterruptedException, ExecutionException, TimeoutException { + + IdentityProvider identityProvider = () -> new SimpleToken("user1", "tokenVal", + System.currentTimeMillis() + 5 * 1000, System.currentTimeMillis(), + Collections.singletonMap("oid", "user1")); + + TokenManager tokenManager = new TokenManager(identityProvider, + new TokenManagerConfig(0.7F, 200, 2000, new RetryPolicy(1, 1))); + + AuthXManager manager = spy(new AuthXManager(tokenManager)); + + final Token[] tokenHolder = new Token[1]; + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + tokenHolder[0] = (Token) args[0]; + return null; + }).when(manager).authenticateConnections(any()); + + manager.start(); + assertEquals(tokenHolder[0].getValue(), "tokenVal"); + } + + @Test + public void testBlockForInitialTokenWhenException() { + String exceptionMessage = "Test exception from identity provider!"; + IdentityProvider identityProvider = () -> { + throw new RuntimeException(exceptionMessage); + }; + + TokenManager tokenManager = new TokenManager(identityProvider, + new TokenManagerConfig(0.7F, 200, 2000, new TokenManagerConfig.RetryPolicy(5, 100))); + + AuthXManager manager = new AuthXManager(tokenManager); + JedisAuthenticationException e = assertThrows(JedisAuthenticationException.class, + () -> manager.start()); + + assertEquals(exceptionMessage, e.getCause().getCause().getMessage()); + } + + @Test + public void testBlockForInitialTokenWhenHangs() { + String exceptionMessage = "AuthXManager failed to start!"; + CountDownLatch latch = new CountDownLatch(1); + IdentityProvider identityProvider = () -> { + try { + latch.await(); + } catch (InterruptedException e) { + } + return null; + }; + + TokenManager tokenManager = new TokenManager(identityProvider, + new TokenManagerConfig(0.7F, 200, 1000, new TokenManagerConfig.RetryPolicy(2, 100))); + + AuthXManager manager = new AuthXManager(tokenManager); + JedisAuthenticationException e = assertThrows(JedisAuthenticationException.class, + () -> manager.start()); + + latch.countDown(); + assertEquals(exceptionMessage, e.getMessage()); + } + + @Test + public void testTokenManagerWithFailingTokenRequest() + throws InterruptedException, ExecutionException, TimeoutException { + int numberOfRetries = 5; + CountDownLatch requesLatch = new CountDownLatch(numberOfRetries); + + IdentityProvider identityProvider = mock(IdentityProvider.class); + when(identityProvider.requestToken()).thenAnswer(invocation -> { + requesLatch.countDown(); + if (requesLatch.getCount() > 0) { + throw new RuntimeException("Test exception from identity provider!"); + } + return new SimpleToken("user1", "tokenValX", System.currentTimeMillis() + 50 * 1000, + System.currentTimeMillis(), Collections.singletonMap("oid", "user1")); + }); + + ArgumentCaptor argument = ArgumentCaptor.forClass(Token.class); + + TokenManager tokenManager = new TokenManager(identityProvider, new TokenManagerConfig(0.7F, 200, + 2000, new TokenManagerConfig.RetryPolicy(numberOfRetries - 1, 100))); + + TokenListener listener = mock(TokenListener.class); + tokenManager.start(listener, false); + requesLatch.await(); + await().pollDelay(ONE_HUNDRED_MILLISECONDS).atMost(FIVE_HUNDRED_MILLISECONDS) + .untilAsserted(() -> verify(listener).onTokenRenewed(argument.capture())); + verify(identityProvider, times(numberOfRetries)).requestToken(); + verify(listener, never()).onError(any()); + assertEquals("tokenValX", argument.getValue().getValue()); + } + + @Test + public void testTokenManagerWithHangingTokenRequest() + throws InterruptedException, ExecutionException, TimeoutException { + int sleepDuration = 200; + int executionTimeout = 100; + int tokenLifetime = 50 * 1000; + int numberOfRetries = 5; + CountDownLatch requesLatch = new CountDownLatch(numberOfRetries); + + IdentityProvider identityProvider = () -> { + requesLatch.countDown(); + if (requesLatch.getCount() > 0) { + try { + Thread.sleep(sleepDuration); + } catch (InterruptedException e) { + } + return null; + } + return new SimpleToken("user1", "tokenValX", System.currentTimeMillis() + tokenLifetime, + System.currentTimeMillis(), Collections.singletonMap("oid", "user1")); + }; + + TokenManager tokenManager = new TokenManager(identityProvider, new TokenManagerConfig(0.7F, 200, + executionTimeout, new TokenManagerConfig.RetryPolicy(numberOfRetries, 100))); + + AuthXManager manager = spy(new AuthXManager(tokenManager)); + AuthXEventListener listener = mock(AuthXEventListener.class); + manager.setListener(listener); + manager.start(); + requesLatch.await(); + verify(listener, never()).onIdentityProviderError(any()); + verify(listener, never()).onConnectionAuthenticationError(any()); + + await().atMost(2, TimeUnit.SECONDS).untilAsserted(() -> { + verify(manager, times(1)).authenticateConnections(any()); + }); + } +}