Skip to content

Commit

Permalink
Merge pull request #2562 from atlanhq/optimise_scrub_search
Browse files Browse the repository at this point in the history
Optimise scrub search
  • Loading branch information
ektavarma10 authored Dec 1, 2023
2 parents 63e6296 + 4df3b2e commit 1266e82
Showing 1 changed file with 180 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,26 @@
import org.apache.atlas.plugin.service.RangerBasePlugin;
import org.apache.atlas.plugin.util.RangerPerfTracer;

import java.util.*;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.ArrayList;
import java.util.Set;
import java.util.Map;
import java.util.HashMap;
import java.util.UUID;
import java.util.Collection;
import java.util.Optional;
import java.util.Objects;


import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.stream.Collectors;

import static org.apache.atlas.authorization.atlas.authorizer.RangerAtlasAuthorizerUtil.*;
import static org.apache.atlas.authorize.AtlasAuthorizationUtils.getCurrentUserGroups;
Expand All @@ -74,6 +93,8 @@ public class RangerAtlasAuthorizer implements AtlasAuthorizer {
add(AtlasPrivilege.ENTITY_UPDATE_CLASSIFICATION);
}};

// Executor on which isAccessAllowed submits one checkAccess task per entity classification.
// NOTE(review): Executors.newCachedThreadPool() creates threads on demand with no upper bound — confirm this is acceptable under heavy search load.
private static final ExecutorService classificationAccessThreadpool = Executors.newCachedThreadPool();
// Executor on which checkAccessAndScrubAsync submits one checkAccessAndScrub task per entity header.
private static final ExecutorService entityAccessThreadpool = Executors.newCachedThreadPool();
@Override
public void init() {
if (LOG.isDebugEnabled()) {
Expand Down Expand Up @@ -578,30 +599,34 @@ public void scrubSearchResults(AtlasSearchResultScrubRequest request, boolean is
if (LOG.isDebugEnabled())
LOG.debug("==> scrubSearchResults(" + request + " " + isScrubAuditEnabled);
RangerPerfTracer perf = null;
long startTime = System.currentTimeMillis();
try {
if (RangerPerfTracer.isPerfTraceEnabled(PERF_LOG))
perf = RangerPerfTracer.getPerfTracer(PERF_LOG, "RangerAtlasAuthorizer.scrubSearchResults(" + request + ")");
AtlasSearchResult result = request.getSearchResult();
List<AtlasEntityHeader> entitiesToCheck = new ArrayList<>();

long startTime = System.currentTimeMillis();
if (CollectionUtils.isNotEmpty(result.getEntities())) {
for (AtlasEntityHeader entity : result.getEntities()) {
checkAccessAndScrub(entity, request, isScrubAuditEnabled);
}
LOG.info("scrubSearchResults ended for entites: " + (System.currentTimeMillis()-startTime));
entitiesToCheck.addAll(result.getEntities());
}

if (CollectionUtils.isNotEmpty(result.getFullTextResult())) {
for (AtlasSearchResult.AtlasFullTextResult fullTextResult : result.getFullTextResult()) {
if (fullTextResult != null)
checkAccessAndScrub(fullTextResult.getEntity(), request, isScrubAuditEnabled);
}
entitiesToCheck.addAll(
result.getFullTextResult()
.stream()
.filter(Objects::nonNull)
.map(res -> res.getEntity())
.collect(Collectors.toList())
);
}

if (MapUtils.isNotEmpty(result.getReferredEntities())) {
for (AtlasEntityHeader entity : result.getReferredEntities().values()) {
checkAccessAndScrub(entity, request, isScrubAuditEnabled);
}
entitiesToCheck.addAll(result.getReferredEntities().values());
}

checkAccessAndScrubAsync(entitiesToCheck, request, isScrubAuditEnabled);
} finally {
LOG.info("scrubSearchResults ended in : "+ (System.currentTimeMillis()-startTime));
RangerPerfTracer.log(perf);
}
if (LOG.isDebugEnabled())
Expand All @@ -622,6 +647,38 @@ public void filterTypesDef(AtlasTypesDefFilterRequest request) throws AtlasAutho

}

/**
 * Submits one {@code checkAccessAndScrub} task per entity to {@code entityAccessThreadpool},
 * blocks until every task has finished, and then surfaces the first authorization failure.
 *
 * @param entitiesToCheck     entity headers to check; one task is created per element
 * @param request             scrub request handed unchanged to every per-entity check
 * @param isScrubAuditEnabled forwarded unchanged to every per-entity check
 * @throws AtlasAuthorizationException the first exception (in list order) caught from a per-entity check
 */
private void checkAccessAndScrubAsync(List<AtlasEntityHeader> entitiesToCheck, AtlasSearchResultScrubRequest request, boolean isScrubAuditEnabled) throws AtlasAuthorizationException {
    LOG.info("Creating futures to check access and scrub " + entitiesToCheck.size() + " entities");

    // The pool is only cast to read its queue depth and active-thread count for the log lines below.
    final ThreadPoolExecutor poolStats = (ThreadPoolExecutor) entityAccessThreadpool;
    LOG.info("Number of tasks in queue " + poolStats.getQueue().size());
    LOG.info("Number of threads in threadpool " + poolStats.getActiveCount());

    final List<CompletableFuture<AtlasAuthorizationException>> pendingChecks = new ArrayList<>(entitiesToCheck.size());
    for (final AtlasEntityHeader entity : entitiesToCheck) {
        // Each task yields null on success, or the authorization exception it caught,
        // so that the checked exception can be carried back to the calling thread.
        pendingChecks.add(CompletableFuture.<AtlasAuthorizationException>supplyAsync(() -> {
            try {
                checkAccessAndScrub(entity, request, isScrubAuditEnabled);
            } catch (AtlasAuthorizationException authFailure) {
                return authFailure;
            }
            return null;
        }, entityAccessThreadpool));
    }

    // Block until every submitted task has completed.
    CompletableFuture.allOf(pendingChecks.toArray(new CompletableFuture[0])).join();

    // Pick the first non-null outcome in submission order; later outcomes are not inspected.
    AtlasAuthorizationException firstFailure = null;
    for (final CompletableFuture<AtlasAuthorizationException> finishedCheck : pendingChecks) {
        final AtlasAuthorizationException outcome = finishedCheck.join();
        if (outcome != null) {
            firstFailure = outcome;
            break;
        }
    }

    LOG.info("Async check access and scrub is complete");
    if (firstFailure != null) {
        throw firstFailure;
    }
}

private void filterTypes(AtlasAccessRequest request, List<? extends AtlasBaseTypeDef> typeDefs)throws AtlasAuthorizationException {
if (typeDefs != null) {
for (ListIterator<? extends AtlasBaseTypeDef> iter = typeDefs.listIterator(); iter.hasNext();) {
Expand Down Expand Up @@ -650,68 +707,52 @@ private boolean isAccessAllowed(AtlasEntityAccessRequest request, RangerAtlasAud
if (LOG.isDebugEnabled()) {
LOG.debug("==> isAccessAllowed(" + request + ")");
}
boolean ret = false;

String uuid = UUID.randomUUID().toString();
long startTime = System.currentTimeMillis();
final String uuid = UUID.randomUUID().toString();
LOG.info("start isAccessAllowed : " + startTime + " uuid: " + uuid);
boolean ret = false;

try {
final String action = request.getAction() != null ? request.getAction().getType() : null;
final Set<String> entityTypes = request.getEntityTypeAndAllSuperTypes();
final String entityId = request.getEntityId();
final String classification = request.getClassification() != null ? request.getClassification().getTypeName() : null;
final RangerAccessRequestImpl rangerRequest = new RangerAccessRequestImpl();
final RangerAccessResourceImpl rangerResource = new RangerAccessResourceImpl();
final String ownerUser = request.getEntity() != null ? (String) request.getEntity().getAttribute(RESOURCE_ENTITY_OWNER) : null;

rangerResource.setValue(RESOURCE_ENTITY_TYPE, entityTypes);
rangerResource.setValue(RESOURCE_ENTITY_ID, entityId);
rangerResource.setOwnerUser(ownerUser);
rangerRequest.setAccessType(action);
rangerRequest.setAction(action);
rangerRequest.setUser(request.getUser());
rangerRequest.setUserGroups(request.getUserGroups());
rangerRequest.setClientIPAddress(request.getClientIPAddress());
rangerRequest.setAccessTime(request.getAccessTime());
rangerRequest.setResource(rangerResource);
rangerRequest.setForwardedAddresses(request.getForwardedAddresses());
rangerRequest.setRemoteIPAddress(request.getRemoteIPAddress());

if (AtlasPrivilege.ENTITY_ADD_LABEL.equals(request.getAction()) || AtlasPrivilege.ENTITY_REMOVE_LABEL.equals(request.getAction())) {
rangerResource.setValue(RESOURCE_ENTITY_LABEL, request.getLabel());
} else if (AtlasPrivilege.ENTITY_UPDATE_BUSINESS_METADATA.equals(request.getAction())) {
rangerResource.setValue(RESOURCE_ENTITY_BUSINESS_METADATA, request.getBusinessMetadata());
} else if (StringUtils.isNotEmpty(classification) && CLASSIFICATION_PRIVILEGES.contains(request.getAction())) {
rangerResource.setValue(RESOURCE_CLASSIFICATION, request.getClassificationTypeAndAllSuperTypes(classification));
}

if (CollectionUtils.isNotEmpty(request.getEntityClassifications())) {
Set<AtlasClassification> entityClassifications = request.getEntityClassifications();
Map<String, Object> contextOjb = rangerRequest.getContext();

Set<RangerTagForEval> rangerTagForEval = getRangerServiceTag(entityClassifications);

if (contextOjb == null) {
Map<String, Object> contextOjb1 = new HashMap<String, Object>();
contextOjb1.put("CLASSIFICATIONS", rangerTagForEval);
rangerRequest.setContext(contextOjb1);
} else {
contextOjb.put("CLASSIFICATIONS", rangerTagForEval);
rangerRequest.setContext(contextOjb);
}
List<CompletableFuture<Boolean>> completableFutures = new ArrayList<>();
LOG.info("isAccessAllowed started : " + startTime);

// check authorization for each classification
LOG.info("classification level authorization started: " + (System.currentTimeMillis()-startTime) + "uuid: "+uuid);
LOG.info("start check authrization for each classification: " + (System.currentTimeMillis() - startTime)+ " uuid: " + uuid);
for (AtlasClassification classificationToAuthorize : request.getEntityClassifications()) {
rangerResource.setValue(RESOURCE_ENTITY_CLASSIFICATION, request.getClassificationTypeAndAllSuperTypes(classificationToAuthorize.getTypeName()));
long rangerRequestCreationStartTime = System.currentTimeMillis();
RangerAccessRequestImpl rangerRequest = createRangerAccessRequest(request, classificationToAuthorize, rangerTagForEval);
LOG.info("Time taken to create a ranger request for uuid: "+uuid+ "is "+ (System.currentTimeMillis()-rangerRequestCreationStartTime));
completableFutures.add(CompletableFuture.supplyAsync(()->checkAccess(rangerRequest, auditHandler, uuid), classificationAccessThreadpool));
}

ret = checkAccess(rangerRequest, auditHandler, uuid);
// wait for all threads to complete their execution
CompletableFuture.allOf(completableFutures.toArray(new CompletableFuture[0])).join();
LOG.info("end check authorization for each classification: " + (System.currentTimeMillis() - startTime) + " uuid: " + uuid);


// if all checkAccess calls return true, then ret is true, else it is false
ret = completableFutures
.stream()
.map(CompletableFuture::join)
.allMatch(result -> result == true);

if (!ret) {
break;
}
}
LOG.info("classification level authorization ended: " + (System.currentTimeMillis()-startTime) + "uuid: "+uuid);
} else {

RangerAccessRequestImpl rangerRequest = new RangerAccessRequestImpl();
RangerAccessResourceImpl rangerResource = new RangerAccessResourceImpl();

initRangerRequest(rangerRequest, request);
initRangerResource(rangerResource, request);

rangerRequest.setResource(rangerResource);

rangerResource.setValue(RESOURCE_ENTITY_CLASSIFICATION, ENTITY_NOT_CLASSIFIED );

ret = checkAccess(rangerRequest, auditHandler, uuid);
Expand All @@ -726,10 +767,81 @@ private boolean isAccessAllowed(AtlasEntityAccessRequest request, RangerAtlasAud
if (LOG.isDebugEnabled()) {
LOG.debug("<== isAccessAllowed(" + request + "): " + ret);
}
LOG.info("isAccessAllowed ended: " + (System.currentTimeMillis()-startTime) + "uuid: "+uuid);

return ret;
}

/**
 * Builds a Ranger access request for authorizing {@code request} against one classification.
 * The request and its resource are populated by {@code initRangerRequest} and
 * {@code initRangerResource}; the resource additionally carries the classification's type
 * and super types, and the request context receives the supplied tags.
 *
 * @param request                   the Atlas entity access request being authorized
 * @param classificationToAuthorize classification whose type name is placed on the resource
 * @param rangerTagForEval          tags stored in the request context under "CLASSIFICATIONS"
 * @return the populated Ranger access request
 */
private RangerAccessRequestImpl createRangerAccessRequest(AtlasEntityAccessRequest request,
        AtlasClassification classificationToAuthorize,
        Set<RangerTagForEval> rangerTagForEval) {

    final long beganAt = System.currentTimeMillis();
    LOG.info("createRangerAccessRequest start: " + beganAt);

    final RangerAccessRequestImpl accessRequest = new RangerAccessRequestImpl();
    initRangerRequest(accessRequest, request);

    final RangerAccessResourceImpl accessResource = new RangerAccessResourceImpl();
    initRangerResource(accessResource, request);

    final String classificationTypeName = classificationToAuthorize.getTypeName();
    accessResource.setValue(RESOURCE_ENTITY_CLASSIFICATION, request.getClassificationTypeAndAllSuperTypes(classificationTypeName));

    accessRequest.setResource(accessResource);
    setClassificationContextForRanger(rangerTagForEval, accessRequest);

    final long elapsedMillis = System.currentTimeMillis() - beganAt;
    LOG.info("createRangerAccessRequest end: " + elapsedMillis);

    return accessRequest;
}

/**
 * Stores {@code rangerTagForEval} in the request's context map under the key "CLASSIFICATIONS".
 * A new map is created when the request has no context yet; in both cases the map is
 * (re)assigned to the request through {@code setContext}.
 *
 * @param rangerTagForEval tags to expose to Ranger
 * @param rangerRequest    request whose context is updated
 */
private static void setClassificationContextForRanger(Set<RangerTagForEval> rangerTagForEval, RangerAccessRequestImpl rangerRequest) {
    Map<String, Object> context = rangerRequest.getContext();
    if (context == null) {
        context = new HashMap<String, Object>();
    }

    context.put("CLASSIFICATIONS", rangerTagForEval);
    rangerRequest.setContext(context);
}

/**
 * Copies the caller-related fields of an Atlas entity access request onto a Ranger request:
 * action (used as both access type and action), user, user groups, the three address fields
 * and access time. The resource is not set here.
 *
 * @param rangerRequest Ranger request to populate
 * @param request       Atlas request providing the values
 */
private void initRangerRequest(RangerAccessRequestImpl rangerRequest, AtlasEntityAccessRequest request) {
    // The action type stays null when the Atlas request carries no action.
    String actionType = null;
    if (request.getAction() != null) {
        actionType = request.getAction().getType();
    }

    // What is being done.
    rangerRequest.setAccessType(actionType);
    rangerRequest.setAction(actionType);

    // Who is doing it.
    rangerRequest.setUser(request.getUser());
    rangerRequest.setUserGroups(request.getUserGroups());

    // From where and when.
    rangerRequest.setClientIPAddress(request.getClientIPAddress());
    rangerRequest.setAccessTime(request.getAccessTime());
    rangerRequest.setForwardedAddresses(request.getForwardedAddresses());
    rangerRequest.setRemoteIPAddress(request.getRemoteIPAddress());
}

/**
 * Populates a Ranger resource from an Atlas entity access request: entity type set, entity id
 * and owner user are always set; then exactly one of label, business metadata or classification
 * is added, depending on the request's action.
 *
 * @param rangerResource Ranger resource to populate
 * @param request        Atlas request providing the values
 */
private void initRangerResource(RangerAccessResourceImpl rangerResource, AtlasEntityAccessRequest request) {
    final Set<String> typeAndSuperTypes = request.getEntityTypeAndAllSuperTypes();
    final String id = request.getEntityId();

    // Owner is read from the entity's attributes; it stays null when the request has no entity.
    String owner = null;
    if (request.getEntity() != null) {
        owner = (String) request.getEntity().getAttribute(RESOURCE_ENTITY_OWNER);
    }

    // Classification type name stays null when the request has no classification.
    String classificationName = null;
    if (request.getClassification() != null) {
        classificationName = request.getClassification().getTypeName();
    }

    rangerResource.setValue(RESOURCE_ENTITY_TYPE, typeAndSuperTypes);
    rangerResource.setValue(RESOURCE_ENTITY_ID, id);
    rangerResource.setOwnerUser(owner);

    // Label add/remove: the resource carries the label.
    if (AtlasPrivilege.ENTITY_ADD_LABEL.equals(request.getAction()) || AtlasPrivilege.ENTITY_REMOVE_LABEL.equals(request.getAction())) {
        rangerResource.setValue(RESOURCE_ENTITY_LABEL, request.getLabel());
        return;
    }

    // Business metadata update: the resource carries the business metadata name.
    if (AtlasPrivilege.ENTITY_UPDATE_BUSINESS_METADATA.equals(request.getAction())) {
        rangerResource.setValue(RESOURCE_ENTITY_BUSINESS_METADATA, request.getBusinessMetadata());
        return;
    }

    // Classification privileges: the resource carries the classification type and its super types.
    if (StringUtils.isNotEmpty(classificationName) && CLASSIFICATION_PRIVILEGES.contains(request.getAction())) {
        rangerResource.setValue(RESOURCE_CLASSIFICATION, request.getClassificationTypeAndAllSuperTypes(classificationName));
    }
}


private void setClassificationsToRequestContext(Set<AtlasClassification> entityClassifications, RangerAccessRequestImpl rangerRequest) {
Map<String, Object> contextOjb = rangerRequest.getContext();
Expand Down Expand Up @@ -861,7 +973,10 @@ private void checkAccessAndScrub(AtlasEntityHeader entity, AtlasSearchResultScru
}

private void checkAccessAndScrub(AtlasEntityHeader entity, AtlasSearchResultScrubRequest request, boolean isScrubAuditEnabled) throws AtlasAuthorizationException {
LOG.info("Number of tasks waiting "+((ThreadPoolExecutor)entityAccessThreadpool).getQueue().size());
LOG.info("Number of threads in threadpool "+((ThreadPoolExecutor)entityAccessThreadpool).getActiveCount());
if (entity != null && request != null) {
long startTimeOfAccessAndScrub = System.currentTimeMillis();
final AtlasEntityAccessRequest entityAccessRequest = new AtlasEntityAccessRequest(request.getTypeRegistry(), AtlasPrivilege.ENTITY_READ, entity, request.getUser(), request.getUserGroups());

entityAccessRequest.setClientIPAddress(request.getClientIPAddress());
Expand All @@ -870,8 +985,12 @@ private void checkAccessAndScrub(AtlasEntityHeader entity, AtlasSearchResultScru

boolean isEntityAccessAllowed = isScrubAuditEnabled ? isAccessAllowed(entityAccessRequest) : isAccessAllowed(entityAccessRequest, null);
if (!isEntityAccessAllowed) {
long startTime = System.currentTimeMillis();
LOG.info("scrubEntityHeader started" + startTime);
scrubEntityHeader(entity, request.getTypeRegistry());
LOG.info("scrubEntityHeader ended" + (System.currentTimeMillis() - startTime));
}
LOG.info("checkAccessAndScrub ended in " + (System.currentTimeMillis()-startTimeOfAccessAndScrub));
}
}

Expand Down

0 comments on commit 1266e82

Please sign in to comment.