Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE;
import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz;
import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound;

import java.util.ArrayList;
Expand Down Expand Up @@ -52,7 +54,6 @@
import org.opensearch.search.SearchHits;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.security.spi.resources.client.ResourceSharingClient;
import org.opensearch.transport.client.Client;

import com.google.common.annotations.VisibleForTesting;
Expand Down Expand Up @@ -147,11 +148,13 @@ public void search(SdkClient sdkClient, SearchRequest request, String tenantId,
mlFeatureEnabledSetting.isMultiTenancyEnabled(),
CommonValue.ML_MODEL_GROUP_INDEX
);
boolean rsClientPresent = ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null;

if (rsClientPresent && user != null && modelAccessControlHelper.modelAccessControlEnabled() && hasModelGroupIndex) {
if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)
&& user != null
&& modelAccessControlHelper.modelAccessControlEnabled()
&& hasModelGroupIndex) {
// RSC fast-path: get accessible group IDs → gate models (IDs or missing)
ResourceSharingClient rsc = ResourceSharingClientAccessor.getInstance().getResourceSharingClient();
var rsc = ResourceSharingClientAccessor.getInstance().getResourceSharingClient();
rsc.getAccessibleResourceIds(CommonValue.ML_MODEL_GROUP_INDEX, ActionListener.wrap(ids -> {
SearchSourceBuilder gated = Optional.ofNullable(request.source()).orElseGet(SearchSourceBuilder::new);
gated.query(rewriteQueryBuilderRSC(gated.query(), ids)); // ids may be empty → "missing only"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package org.opensearch.ml.action.model_group;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_GROUP_ID;

import org.opensearch.ExceptionsHelper;
Expand All @@ -27,7 +29,6 @@
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.ml.common.ResourceSharingClientAccessor;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteAction;
Expand Down Expand Up @@ -96,7 +97,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);

// if resource sharing feature is enabled, access will be automatically checked by security plugin, so no need to check again
if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) {
if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) {
checkForAssociatedModels(modelGroupId, tenantId, wrappedListener);
} else {
validateAndDeleteModelGroup(modelGroupId, tenantId, wrappedListener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE;
import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz;

import org.opensearch.ExceptionsHelper;
import org.opensearch.OpenSearchStatusException;
Expand All @@ -27,7 +29,6 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.common.MLModelGroup;
import org.opensearch.ml.common.ResourceSharingClientAccessor;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.model_group.MLModelGroupGetAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupGetRequest;
Expand Down Expand Up @@ -186,7 +187,7 @@ private void validateModelGroupAccess(
) {
// if resource sharing feature is enabled, security plugin will have automatically evaluated access to this model group, hence no
// need to validate again
if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) {
if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) {
wrappedListener.onResponse(MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build());
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE;
import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz;
import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound;

import java.util.Collections;
Expand All @@ -31,7 +33,6 @@
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.security.spi.resources.client.ResourceSharingClient;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;
Expand Down Expand Up @@ -89,7 +90,7 @@ private void preProcessRoleAndPerformSearch(
.wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener));

// If resource-sharing feature is enabled, we fetch accessible model-groups and restrict the search to those model-groups only.
if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) {
if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) {
// If a model-group is shared, then it will have been shared at-least at read access, hence the final result is guaranteed
// to only contain model-groups that the user at-least has read access to.
addAccessibleModelGroupsFilterAndSearch(tenantId, request, doubleWrappedListener);
Expand All @@ -113,7 +114,7 @@ private void addAccessibleModelGroupsFilterAndSearch(
ActionListener<SearchResponse> wrappedListener
) {
SearchSourceBuilder sourceBuilder = request.source() != null ? request.source() : new SearchSourceBuilder();
ResourceSharingClient rsc = ResourceSharingClientAccessor.getInstance().getResourceSharingClient();
var rsc = ResourceSharingClientAccessor.getInstance().getResourceSharingClient();
// filter by accessible model-groups
rsc.getAccessibleResourceIds(ML_MODEL_GROUP_INDEX, ActionListener.wrap(ids -> {
sourceBuilder.query(modelAccessControlHelper.mergeWithAccessFilter(sourceBuilder.query(), ids));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.BACKEND_ROLES_FIELD;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE;
import static org.opensearch.ml.helper.ModelAccessControlHelper.shouldUseResourceAuthz;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;

import java.time.Instant;
Expand Down Expand Up @@ -36,7 +38,6 @@
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.MLModelGroup;
import org.opensearch.ml.common.ResourceSharingClientAccessor;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupAction;
Expand Down Expand Up @@ -150,7 +151,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUpda
)) {
// NOTE all sharing and revoking must happen through share API exposed by security plugin
// client == null -> feature is disabled, follow old route
if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() == null) {
if (!shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) {
// TODO: At some point, this call must be replaced by the one above, (i.e. no user info to
// be stored in model-group index)
if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.BACKEND_ROLES_FIELD;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_RESOURCE_TYPE;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED;

import java.util.Collections;
Expand Down Expand Up @@ -57,7 +58,6 @@
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.security.spi.resources.client.ResourceSharingClient;
import org.opensearch.transport.client.Client;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -98,8 +98,8 @@ public void validateModelGroupAccess(User user, String modelGroupId, String acti
listener.onResponse(true);
return;
}
if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) {
ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient();
if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) {
var resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient();
resourceSharingClient.verifyAccess(modelGroupId, ML_MODEL_GROUP_INDEX, action, ActionListener.wrap(isAuthorized -> {
if (!isAuthorized) {
listener
Expand Down Expand Up @@ -173,8 +173,8 @@ public void validateModelGroupAccess(
listener.onResponse(true);
return;
}
if (ResourceSharingClientAccessor.getInstance().getResourceSharingClient() != null) {
ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient();
if (shouldUseResourceAuthz(ML_MODEL_GROUP_RESOURCE_TYPE)) {
var resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient();
resourceSharingClient.verifyAccess(modelGroupId, ML_MODEL_GROUP_INDEX, action, ActionListener.wrap(isAuthorized -> {
if (!isAuthorized) {
listener
Expand Down Expand Up @@ -288,6 +288,16 @@ public void checkModelGroupPermission(MLModelGroup mlModelGroup, User user, Acti
}
}

/**
* Checks whether to utilize new ResourceAuthz
* @param resourceType for which to decide whether to use resource authz
* @return true if the resource-sharing feature is enabled, false otherwise.
*/
public static boolean shouldUseResourceAuthz(String resourceType) {
var client = ResourceSharingClientAccessor.getInstance().getResourceSharingClient();
return client != null && client.isFeatureEnabledForType(resourceType);
}

public boolean skipModelAccessControl(User user) {
// Case 1: user == null when 1. Security is disabled. 2. When user is super-admin
// Case 2: If Security is enabled and filter is disabled, proceed with search as
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
Expand Down Expand Up @@ -39,6 +42,7 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.common.ResourceSharingClientAccessor;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteRequest;
import org.opensearch.ml.helper.ModelAccessControlHelper;
Expand All @@ -48,6 +52,7 @@
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.security.spi.resources.client.ResourceSharingClient;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
Expand Down Expand Up @@ -291,6 +296,110 @@ public void testDeleteModelGroup_Failure() {
assertEquals("Failed to delete data object from index .plugins-ml-model-group", argumentCaptor.getValue().getMessage());
}

@Test
public void test_RSC_FeatureEnabled_TypeEnabled_SkipsLegacyValidation() throws Exception {
ResourceSharingClient rsc = mock(ResourceSharingClient.class);
ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc);

// Feature enabled for this type => apply resource sharing
when(rsc.isFeatureEnabledForType(any())).thenReturn(true);

// Associated models search -> empty => proceed to delete
SearchResponse empty = getEmptySearchResponse();
doAnswer(inv -> {
ActionListener<SearchResponse> l = inv.getArgument(1);
l.onResponse(empty);
return null;
}).when(client).search(any(), isA(ActionListener.class));

// Delete succeeds
doAnswer(inv -> {
ActionListener<DeleteResponse> l = inv.getArgument(1);
l.onResponse(deleteResponse);
return null;
}).when(client).delete(any(), any());

deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener);

// Legacy validation must be skipped
verify(modelAccessControlHelper, never()).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any());

// RSC path still does search (for associated models) then delete
verify(client, times(1)).search(any(), any());
verify(client, times(1)).delete(any(), any());

verify(actionListener, times(1)).onResponse(any(DeleteResponse.class));
}

@Test
public void test_RSC_FeatureEnabled_TypeDisabled_UsesLegacyValidation() throws Exception {
// Feature enabled globally but TYPE disabled → legacy path
ResourceSharingClient rsc = mock(ResourceSharingClient.class);
ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc);

when(rsc.isFeatureEnabledForType(any())).thenReturn(false);

// Associated models search -> empty => proceed to delete
SearchResponse empty = getEmptySearchResponse();
doAnswer(inv -> {
ActionListener<SearchResponse> l = inv.getArgument(1);
l.onResponse(empty);
return null;
}).when(client).search(any(), isA(ActionListener.class));

// Delete succeeds
doAnswer(inv -> {
ActionListener<DeleteResponse> l = inv.getArgument(1);
l.onResponse(deleteResponse);
return null;
}).when(client).delete(any(), any());

deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener);

// Legacy validation must run
verify(modelAccessControlHelper, times(1)).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any());

// Search (models) + delete executed
verify(client, times(1)).search(any(), any());
verify(client, times(1)).delete(any(), any());

verify(actionListener, times(1)).onResponse(any(DeleteResponse.class));
}

@Test
public void test_RSC_FeatureDisabled_UsesLegacyValidation() throws Exception {
// Feature disabled by forcing the gate to false
ResourceSharingClientAccessor.getInstance().setResourceSharingClient(null);

// (setup() already stubs validateModelGroupAccess(...)->onResponse(true))

// Associated models search -> empty => proceed to delete
SearchResponse empty = getEmptySearchResponse();
doAnswer(inv -> {
ActionListener<SearchResponse> l = inv.getArgument(1);
l.onResponse(empty);
return null;
}).when(client).search(any(), isA(ActionListener.class));

// Delete succeeds
doAnswer(inv -> {
ActionListener<DeleteResponse> l = inv.getArgument(1);
l.onResponse(deleteResponse);
return null;
}).when(client).delete(any(), any());

deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener);

// Legacy validation must run
verify(modelAccessControlHelper, times(1)).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any());

// Search (models) + delete executed
verify(client, times(1)).search(any(), any());
verify(client, times(1)).delete(any(), any());

verify(actionListener, times(1)).onResponse(any(DeleteResponse.class));
}

private SearchResponse getEmptySearchResponse() {
SearchHits hits = new SearchHits(new SearchHit[0], null, Float.NaN);
SearchResponseSections searchSections = new SearchResponseSections(hits, InternalAggregations.EMPTY, null, true, false, null, 1);
Expand Down
Loading
Loading