Skip to content

Commit

Permalink
Prevent client cache stampede after invalidation of a client or on st…
Browse files Browse the repository at this point in the history
…artup (keycloak#25217)

Closes keycloak#24202

Signed-off-by: Alexander Schwartz <[email protected]>
  • Loading branch information
ahus1 authored Dec 5, 2023
1 parent e69031d commit e4be3ed
Showing 1 changed file with 44 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1124,30 +1124,46 @@ private void addGroupEventIfAbsent(InvalidationEvent eventToAdd) {

@Override
public ClientModel getClientById(RealmModel realm, String id) {
if (invalidations.contains(id) || listInvalidations.contains(realm.getId())) {
return getClientDelegate().getClientById(realm, id);
} else if (managedApplications.containsKey(id)) {
return managedApplications.get(id);
}
CachedClient cached = cache.get(id, CachedClient.class);
if (cached != null && !cached.getRealm().equals(realm.getId())) {
cached = null;
}
boolean queryDB = invalidations.contains(id) || listInvalidations.contains(realm.getId());
if (queryDB) { // short-circuit if the client has been potentially invalidated
return getClientDelegate().getClientById(realm, id);
}
ClientModel adapter;
if (cached != null) {
logger.tracev("client by id cache hit: {0}", cached.getClientId());
adapter = validateCache(realm, cached);
} else {
adapter = cache.computeSerialized(session, id, (key, keycloakSession) -> prepareCachedClientById(realm, id));
if (adapter == null) {
return adapter;
}
}
managedApplications.put(id, adapter);
return adapter;
}

private ClientModel prepareCachedClientById(RealmModel realm, String id) {
CachedClient cached = cache.get(id, CachedClient.class);
ClientModel adapter;
if (cached != null && !cached.getRealm().equals(realm.getId())) {
cached = null;
}
if (cached == null) {
Long loaded = cache.getCurrentRevision(id);
ClientModel model = getClientDelegate().getClientById(realm, id);
if (model == null) return null;
ClientModel adapter = cacheClient(realm, model, loaded);
managedApplications.put(id, adapter);
return adapter;
} else if (managedApplications.containsKey(id)) {
return managedApplications.get(id);
if (model == null) {
return null;
}
adapter = cacheClient(realm, model, loaded);
} else {
logger.tracev("client by id cache hit after locking: {0}", cached.getClientId());
adapter = validateCache(realm, cached);
}
ClientModel adapter = validateCache(realm, cached);
managedApplications.put(id, adapter);
return adapter;
}

Expand Down Expand Up @@ -1230,31 +1246,37 @@ public Stream<ClientModel> searchClientsByAuthenticationFlowBindingOverrides(Rea
@Override
public ClientModel getClientByClientId(RealmModel realm, String clientId) {
String cacheKey = getClientByClientIdCacheKey(clientId, realm.getId());
ClientListQuery query = cache.get(cacheKey, ClientListQuery.class);
String id = null;

boolean queryDB = invalidations.contains(cacheKey) || listInvalidations.contains(realm.getId());
if (queryDB) { // short-circuit if the client has been potentially invalidated
if (invalidations.contains(cacheKey) || listInvalidations.contains(realm.getId())) {
return getClientDelegate().getClientByClientId(realm, clientId);
}
ClientListQuery query = cache.get(cacheKey, ClientListQuery.class);
if (query != null) {
logger.tracev("client by name cache hit: {0}", clientId);
String id = query.getClients().iterator().next();
return getClientById(realm, id);
} else {
return cache.computeSerialized(session, cacheKey, (key, keycloakSession) -> prepareCachedClientByClientId(realm, clientId, key));
}
}

private ClientModel prepareCachedClientByClientId(RealmModel realm, String clientId, String cacheKey) {
ClientListQuery query = cache.get(cacheKey, ClientListQuery.class);
String id;
if (query == null) {
Long loaded = cache.getCurrentRevision(cacheKey);
ClientModel model = getClientDelegate().getClientByClientId(realm, clientId);
if (model == null) return null;
if (invalidations.contains(model.getId())) return model;
if (model == null) {
return null;
}
id = model.getId();
query = new ClientListQuery(loaded, cacheKey, realm, id);
logger.tracev("adding client by name cache miss: {0}", clientId);
cache.addRevisioned(query, startupRevision);
if (invalidations.contains(model.getId())) {
return model;
}
} else {
id = query.getClients().iterator().next();
if (invalidations.contains(id)) {
return getClientDelegate().getClientByClientId(realm, clientId);
}
}
return getClientById(realm, id);
}
Expand Down

0 comments on commit e4be3ed

Please sign in to comment.