diff --git a/pom.xml b/pom.xml
index 5b9cf55500..add6b378ea 100644
--- a/pom.xml
+++ b/pom.xml
@@ -75,6 +75,12 @@
2.11.0
+
+ redis.clients.authentication
+ redis-authx-core
+ 0.1.1-beta1
+
+
@@ -150,6 +156,13 @@
test
+
+ redis.clients.authentication
+ redis-authx-entraid
+ 0.1.1-beta1
+ test
+
+
io.github.resilience4j
diff --git a/src/main/java/redis/clients/jedis/Connection.java b/src/main/java/redis/clients/jedis/Connection.java
index 2860866c6e..de473d0b8e 100644
--- a/src/main/java/redis/clients/jedis/Connection.java
+++ b/src/main/java/redis/clients/jedis/Connection.java
@@ -14,12 +14,14 @@
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
+import java.util.concurrent.atomic.AtomicReference;
import redis.clients.jedis.Protocol.Command;
import redis.clients.jedis.Protocol.Keyword;
import redis.clients.jedis.annots.Experimental;
import redis.clients.jedis.args.ClientAttributeOption;
import redis.clients.jedis.args.Rawable;
+import redis.clients.jedis.authentication.AuthXManager;
import redis.clients.jedis.commands.ProtocolCommand;
import redis.clients.jedis.exceptions.JedisConnectionException;
import redis.clients.jedis.exceptions.JedisDataException;
@@ -44,6 +46,8 @@ public class Connection implements Closeable {
private String strVal;
protected String server;
protected String version;
+ private AtomicReference currentCredentials = new AtomicReference<>(null);
+ private AuthXManager authXManager;
public Connection() {
this(Protocol.DEFAULT_HOST, Protocol.DEFAULT_PORT);
@@ -63,6 +67,7 @@ public Connection(final HostAndPort hostAndPort, final JedisClientConfig clientC
public Connection(final JedisSocketFactory socketFactory) {
this.socketFactory = socketFactory;
+ this.authXManager = null;
}
public Connection(final JedisSocketFactory socketFactory, JedisClientConfig clientConfig) {
@@ -93,8 +98,8 @@ public String toIdentityString() {
SocketAddress remoteAddr = socket.getRemoteSocketAddress();
SocketAddress localAddr = socket.getLocalSocketAddress();
if (remoteAddr != null) {
- strVal = String.format("%s{id: 0x%X, L:%s %c R:%s}", className, id,
- localAddr, (broken ? '!' : '-'), remoteAddr);
+ strVal = String.format("%s{id: 0x%X, L:%s %c R:%s}", className, id, localAddr,
+ (broken ? '!' : '-'), remoteAddr);
} else if (localAddr != null) {
strVal = String.format("%s{id: 0x%X, L:%s}", className, id, localAddr);
} else {
@@ -438,8 +443,8 @@ private static boolean validateClientInfo(String info) {
for (int i = 0; i < info.length(); i++) {
char c = info.charAt(i);
if (c < '!' || c > '~') {
- throw new JedisValidationException("client info cannot contain spaces, "
- + "newlines or special characters.");
+ throw new JedisValidationException(
+ "client info cannot contain spaces, " + "newlines or special characters.");
}
}
return true;
@@ -451,7 +456,13 @@ protected void initializeFromClientConfig(final JedisClientConfig config) {
protocol = config.getRedisProtocol();
- final Supplier credentialsProvider = config.getCredentialsProvider();
+ Supplier credentialsProvider = config.getCredentialsProvider();
+
+ authXManager = config.getAuthXManager();
+ if (authXManager != null) {
+ credentialsProvider = authXManager;
+ }
+
if (credentialsProvider instanceof RedisCredentialsProvider) {
final RedisCredentialsProvider redisCredentialsProvider = (RedisCredentialsProvider) credentialsProvider;
try {
@@ -469,7 +480,8 @@ protected void initializeFromClientConfig(final JedisClientConfig config) {
String clientName = config.getClientName();
if (clientName != null && validateClientInfo(clientName)) {
- fireAndForgetMsg.add(new CommandArguments(Command.CLIENT).add(Keyword.SETNAME).add(clientName));
+ fireAndForgetMsg
+ .add(new CommandArguments(Command.CLIENT).add(Keyword.SETNAME).add(clientName));
}
ClientSetInfoConfig setInfoConfig = config.getClientSetInfoConfig();
@@ -525,12 +537,13 @@ private void helloAndAuth(final RedisProtocol protocol, final RedisCredentials c
if (protocol != null && credentials != null && credentials.getUser() != null) {
byte[] rawPass = encodeToBytes(credentials.getPassword());
try {
- helloResult = hello(encode(protocol.version()), Keyword.AUTH.getRaw(), encode(credentials.getUser()), rawPass);
+ helloResult = hello(encode(protocol.version()), Keyword.AUTH.getRaw(),
+ encode(credentials.getUser()), rawPass);
} finally {
Arrays.fill(rawPass, (byte) 0); // clear sensitive data
}
} else {
- auth(credentials);
+ authenticate(credentials);
helloResult = protocol == null ? null : hello(encode(protocol.version()));
}
if (helloResult != null) {
@@ -542,9 +555,13 @@ private void helloAndAuth(final RedisProtocol protocol, final RedisCredentials c
// handled in RedisCredentialsProvider.cleanUp()
}
- private void auth(RedisCredentials credentials) {
+ public void setCredentials(RedisCredentials credentials) {
+ currentCredentials.set(credentials);
+ }
+
+ private String authenticate(RedisCredentials credentials) {
if (credentials == null || credentials.getPassword() == null) {
- return;
+ return null;
}
byte[] rawPass = encodeToBytes(credentials.getPassword());
try {
@@ -556,7 +573,11 @@ private void auth(RedisCredentials credentials) {
} finally {
Arrays.fill(rawPass, (byte) 0); // clear sensitive data
}
- getStatusCodeReply();
+ return getStatusCodeReply();
+ }
+
+ public String reAuthenticate() {
+ return authenticate(currentCredentials.getAndSet(null));
}
protected Map hello(byte[]... args) {
@@ -585,4 +606,12 @@ public boolean ping() {
}
return true;
}
+
+ protected boolean isTokenBasedAuthenticationEnabled() {
+ return authXManager != null;
+ }
+
+ protected AuthXManager getAuthXManager() {
+ return authXManager;
+ }
}
diff --git a/src/main/java/redis/clients/jedis/ConnectionFactory.java b/src/main/java/redis/clients/jedis/ConnectionFactory.java
index cc53df56f0..7440417152 100644
--- a/src/main/java/redis/clients/jedis/ConnectionFactory.java
+++ b/src/main/java/redis/clients/jedis/ConnectionFactory.java
@@ -6,7 +6,12 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import java.util.function.Supplier;
+
import redis.clients.jedis.annots.Experimental;
+import redis.clients.jedis.authentication.AuthXManager;
+import redis.clients.jedis.authentication.JedisAuthenticationException;
+import redis.clients.jedis.authentication.AuthXEventListener;
import redis.clients.jedis.csc.Cache;
import redis.clients.jedis.csc.CacheConnection;
import redis.clients.jedis.exceptions.JedisException;
@@ -20,28 +25,52 @@ public class ConnectionFactory implements PooledObjectFactory {
private final JedisSocketFactory jedisSocketFactory;
private final JedisClientConfig clientConfig;
- private Cache clientSideCache = null;
+ private final Cache clientSideCache;
+ private final Supplier objectMaker;
+
+ private final AuthXEventListener authXEventListener;
public ConnectionFactory(final HostAndPort hostAndPort) {
- this.clientConfig = DefaultJedisClientConfig.builder().build();
- this.jedisSocketFactory = new DefaultJedisSocketFactory(hostAndPort);
+ this(hostAndPort, DefaultJedisClientConfig.builder().build(), null);
}
public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig) {
- this.clientConfig = clientConfig;
- this.jedisSocketFactory = new DefaultJedisSocketFactory(hostAndPort, this.clientConfig);
+ this(hostAndPort, clientConfig, null);
}
@Experimental
- public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig, Cache csCache) {
- this.clientConfig = clientConfig;
- this.jedisSocketFactory = new DefaultJedisSocketFactory(hostAndPort, this.clientConfig);
- this.clientSideCache = csCache;
+ public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig,
+ Cache csCache) {
+ this(new DefaultJedisSocketFactory(hostAndPort, clientConfig), clientConfig, csCache);
}
- public ConnectionFactory(final JedisSocketFactory jedisSocketFactory, final JedisClientConfig clientConfig) {
- this.clientConfig = clientConfig;
+ public ConnectionFactory(final JedisSocketFactory jedisSocketFactory,
+ final JedisClientConfig clientConfig) {
+ this(jedisSocketFactory, clientConfig, null);
+ }
+
+ private ConnectionFactory(final JedisSocketFactory jedisSocketFactory,
+ final JedisClientConfig clientConfig, Cache csCache) {
+
this.jedisSocketFactory = jedisSocketFactory;
+ this.clientSideCache = csCache;
+ this.clientConfig = clientConfig;
+
+ AuthXManager authXManager = clientConfig.getAuthXManager();
+ if (authXManager == null) {
+ this.objectMaker = connectionSupplier();
+ this.authXEventListener = AuthXEventListener.NOOP_LISTENER;
+ } else {
+ Supplier supplier = connectionSupplier();
+ this.objectMaker = () -> (Connection) authXManager.addConnection(supplier.get());
+ this.authXEventListener = authXManager.getListener();
+ authXManager.start();
+ }
+ }
+
+ private Supplier connectionSupplier() {
+ return clientSideCache == null ? () -> new Connection(jedisSocketFactory, clientConfig)
+ : () -> new CacheConnection(jedisSocketFactory, clientConfig, clientSideCache);
}
@Override
@@ -64,8 +93,7 @@ public void destroyObject(PooledObject pooledConnection) throws Exce
@Override
public PooledObject makeObject() throws Exception {
try {
- Connection jedis = clientSideCache == null ? new Connection(jedisSocketFactory, clientConfig)
- : new CacheConnection(jedisSocketFactory, clientConfig, clientSideCache);
+ Connection jedis = objectMaker.get();
return new DefaultPooledObject<>(jedis);
} catch (JedisException je) {
logger.debug("Error while makeObject", je);
@@ -76,6 +104,8 @@ public PooledObject makeObject() throws Exception {
@Override
public void passivateObject(PooledObject pooledConnection) throws Exception {
// TODO maybe should select db 0? Not sure right now.
+ Connection jedis = pooledConnection.getObject();
+ reAuthenticate(jedis);
}
@Override
@@ -83,10 +113,31 @@ public boolean validateObject(PooledObject pooledConnection) {
final Connection jedis = pooledConnection.getObject();
try {
// check HostAndPort ??
- return jedis.isConnected() && jedis.ping();
+ if (!jedis.isConnected()) {
+ return false;
+ }
+ reAuthenticate(jedis);
+ return jedis.ping();
} catch (final Exception e) {
logger.warn("Error while validating pooled Connection object.", e);
return false;
}
}
+
+ private void reAuthenticate(Connection jedis) throws Exception {
+ try {
+ String result = jedis.reAuthenticate();
+ if (result != null && !result.equals("OK")) {
+ String msg = "Re-authentication failed with server response: " + result;
+ Exception failedAuth = new JedisAuthenticationException(msg);
+ logger.error(failedAuth.getMessage(), failedAuth);
+ authXEventListener.onConnectionAuthenticationError(failedAuth);
+ return;
+ }
+ } catch (Exception e) {
+ logger.error("Error while re-authenticating connection", e);
+ authXEventListener.onConnectionAuthenticationError(e);
+ throw e;
+ }
+ }
}
diff --git a/src/main/java/redis/clients/jedis/ConnectionPool.java b/src/main/java/redis/clients/jedis/ConnectionPool.java
index 40d4861f98..2ae1401081 100644
--- a/src/main/java/redis/clients/jedis/ConnectionPool.java
+++ b/src/main/java/redis/clients/jedis/ConnectionPool.java
@@ -2,19 +2,27 @@
import org.apache.commons.pool2.PooledObjectFactory;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
+
import redis.clients.jedis.annots.Experimental;
+import redis.clients.jedis.authentication.AuthXManager;
import redis.clients.jedis.csc.Cache;
+import redis.clients.jedis.exceptions.JedisException;
import redis.clients.jedis.util.Pool;
public class ConnectionPool extends Pool {
+ private AuthXManager authXManager;
+
public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig) {
this(new ConnectionFactory(hostAndPort, clientConfig));
+ attachAuthenticationListener(clientConfig.getAuthXManager());
}
@Experimental
- public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, Cache clientSideCache) {
+ public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig,
+ Cache clientSideCache) {
this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache));
+ attachAuthenticationListener(clientConfig.getAuthXManager());
}
public ConnectionPool(PooledObjectFactory factory) {
@@ -24,12 +32,14 @@ public ConnectionPool(PooledObjectFactory factory) {
public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig,
GenericObjectPoolConfig poolConfig) {
this(new ConnectionFactory(hostAndPort, clientConfig), poolConfig);
+ attachAuthenticationListener(clientConfig.getAuthXManager());
}
@Experimental
- public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, Cache clientSideCache,
- GenericObjectPoolConfig poolConfig) {
+ public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig,
+ Cache clientSideCache, GenericObjectPoolConfig poolConfig) {
this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache), poolConfig);
+ attachAuthenticationListener(clientConfig.getAuthXManager());
}
public ConnectionPool(PooledObjectFactory factory,
@@ -43,4 +53,29 @@ public Connection getResource() {
conn.setHandlingPool(this);
return conn;
}
+
+ @Override
+ public void close() {
+ try {
+ if (authXManager != null) {
+ authXManager.stop();
+ }
+ } finally {
+ super.close();
+ }
+ }
+
+ private void attachAuthenticationListener(AuthXManager authXManager) {
+ this.authXManager = authXManager;
+ if (authXManager != null) {
+ authXManager.addPostAuthenticationHook(token -> {
+ try {
+ // this is to trigger validations on each connection via ConnectionFactory
+ evict();
+ } catch (Exception e) {
+ throw new JedisException("Failed to evict connections from pool", e);
+ }
+ });
+ }
+ }
}
diff --git a/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java b/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java
index d5468f3a46..7d41d9d28a 100644
--- a/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java
+++ b/src/main/java/redis/clients/jedis/DefaultJedisClientConfig.java
@@ -5,6 +5,8 @@
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocketFactory;
+import redis.clients.jedis.authentication.AuthXManager;
+
public final class DefaultJedisClientConfig implements JedisClientConfig {
private final RedisProtocol redisProtocol;
@@ -29,6 +31,8 @@ public final class DefaultJedisClientConfig implements JedisClientConfig {
private final boolean readOnlyForRedisClusterReplicas;
+ private final AuthXManager authXManager;
+
private DefaultJedisClientConfig(DefaultJedisClientConfig.Builder builder) {
this.redisProtocol = builder.redisProtocol;
this.connectionTimeoutMillis = builder.connectionTimeoutMillis;
@@ -45,6 +49,7 @@ private DefaultJedisClientConfig(DefaultJedisClientConfig.Builder builder) {
this.hostAndPortMapper = builder.hostAndPortMapper;
this.clientSetInfoConfig = builder.clientSetInfoConfig;
this.readOnlyForRedisClusterReplicas = builder.readOnlyForRedisClusterReplicas;
+ this.authXManager = builder.authXManager;
}
@Override
@@ -83,6 +88,11 @@ public Supplier getCredentialsProvider() {
return credentialsProvider;
}
+ @Override
+ public AuthXManager getAuthXManager() {
+ return authXManager;
+ }
+
@Override
public int getDatabase() {
return database;
@@ -163,6 +173,8 @@ public static class Builder {
private boolean readOnlyForRedisClusterReplicas = false;
+ private AuthXManager authXManager = null;
+
private Builder() {
}
@@ -279,6 +291,30 @@ public Builder readOnlyForRedisClusterReplicas() {
this.readOnlyForRedisClusterReplicas = true;
return this;
}
+
+ public Builder authXManager(AuthXManager authXManager) {
+ this.authXManager = authXManager;
+ return this;
+ }
+
+ public Builder from(JedisClientConfig instance) {
+ this.redisProtocol = instance.getRedisProtocol();
+ this.connectionTimeoutMillis = instance.getConnectionTimeoutMillis();
+ this.socketTimeoutMillis = instance.getSocketTimeoutMillis();
+ this.blockingSocketTimeoutMillis = instance.getBlockingSocketTimeoutMillis();
+ this.credentialsProvider = instance.getCredentialsProvider();
+ this.database = instance.getDatabase();
+ this.clientName = instance.getClientName();
+ this.ssl = instance.isSsl();
+ this.sslSocketFactory = instance.getSslSocketFactory();
+ this.sslParameters = instance.getSslParameters();
+ this.hostnameVerifier = instance.getHostnameVerifier();
+ this.hostAndPortMapper = instance.getHostAndPortMapper();
+ this.clientSetInfoConfig = instance.getClientSetInfoConfig();
+ this.readOnlyForRedisClusterReplicas = instance.isReadOnlyForRedisClusterReplicas();
+ this.authXManager = instance.getAuthXManager();
+ return this;
+ }
}
public static DefaultJedisClientConfig create(int connectionTimeoutMillis, int soTimeoutMillis,
@@ -328,6 +364,9 @@ public static DefaultJedisClientConfig copyConfig(JedisClientConfig copy) {
if (copy.isReadOnlyForRedisClusterReplicas()) {
builder.readOnlyForRedisClusterReplicas();
}
+
+ builder.authXManager(copy.getAuthXManager());
+
return builder.build();
}
}
diff --git a/src/main/java/redis/clients/jedis/JedisClientConfig.java b/src/main/java/redis/clients/jedis/JedisClientConfig.java
index 8bd18b5aaa..ce7fd82de4 100644
--- a/src/main/java/redis/clients/jedis/JedisClientConfig.java
+++ b/src/main/java/redis/clients/jedis/JedisClientConfig.java
@@ -5,6 +5,8 @@
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocketFactory;
+import redis.clients.jedis.authentication.AuthXManager;
+
public interface JedisClientConfig {
default RedisProtocol getRedisProtocol() {
@@ -50,6 +52,10 @@ default Supplier getCredentialsProvider() {
new DefaultRedisCredentials(getUser(), getPassword()));
}
+ default AuthXManager getAuthXManager() {
+ return null;
+ }
+
default int getDatabase() {
return Protocol.DEFAULT_DATABASE;
}
diff --git a/src/main/java/redis/clients/jedis/JedisClusterInfoCache.java b/src/main/java/redis/clients/jedis/JedisClusterInfoCache.java
index ec63c5206a..9462527c0f 100644
--- a/src/main/java/redis/clients/jedis/JedisClusterInfoCache.java
+++ b/src/main/java/redis/clients/jedis/JedisClusterInfoCache.java
@@ -103,6 +103,9 @@ public JedisClusterInfoCache(final JedisClientConfig clientConfig, Cache clientS
this.clientConfig = clientConfig;
this.clientSideCache = clientSideCache;
this.startNodes = startNodes;
+ if (clientConfig.getAuthXManager() != null) {
+ clientConfig.getAuthXManager().start();
+ }
if (topologyRefreshPeriod != null) {
logger.info("Cluster topology refresh start, period: {}, startNodes: {}", topologyRefreshPeriod, startNodes);
topologyRefreshExecutor = Executors.newSingleThreadScheduledExecutor();
diff --git a/src/main/java/redis/clients/jedis/JedisPubSubBase.java b/src/main/java/redis/clients/jedis/JedisPubSubBase.java
index bf9d0a32c5..91fee36c58 100644
--- a/src/main/java/redis/clients/jedis/JedisPubSubBase.java
+++ b/src/main/java/redis/clients/jedis/JedisPubSubBase.java
@@ -4,6 +4,7 @@
import java.util.Arrays;
import java.util.List;
+import java.util.function.Consumer;
import redis.clients.jedis.Protocol.Command;
import redis.clients.jedis.exceptions.JedisException;
@@ -12,7 +13,8 @@
public abstract class JedisPubSubBase {
private int subscribedChannels = 0;
- private volatile Connection client;
+ private final JedisSafeAuthenticator authenticator = new JedisSafeAuthenticator();
+ private final Consumer