Skip to content

Commit

Permalink
enhancement: wait model undeploy before delete; refactor the wait res…
Browse files Browse the repository at this point in the history
…ponse logic

Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws committed Jan 30, 2024
1 parent 893a4a2 commit 0331a81
Showing 1 changed file with 34 additions and 13 deletions.
47 changes: 34 additions & 13 deletions src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import org.apache.commons.lang3.StringUtils;
Expand All @@ -35,6 +36,7 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
Expand Down Expand Up @@ -124,27 +126,35 @@ protected String indexMonitor(String monitorAsJsonString) {
}

@SneakyThrows
protected Map<String, Object> waitTaskComplete(String taskId) {
protected Map<String, Object> waitResponseMeetingCondition(
String method,
String endpoint,
String jsonEntity,
Predicate<Map<String, Object>> condition
) {
for (int i = 0; i < MAX_TASK_RESULT_QUERY_TIME_IN_SECOND; i++) {
Response response = makeRequest(client(), "GET", "/_plugins/_ml/tasks/" + taskId, null, (String) null, null);
Response response = makeRequest(client(), method, endpoint, null, jsonEntity, null);
assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
Map<String, Object> responseInMap = parseResponseToMap(response);
String state = responseInMap.get(MLTask.STATE_FIELD).toString();
if (state.equals(MLTaskState.COMPLETED.toString())) {
if (condition.test(responseInMap)) {
return responseInMap;
}
if (state.equals(MLTaskState.FAILED.toString())
|| state.equals(MLTaskState.CANCELLED.toString())
|| state.equals(MLTaskState.COMPLETED_WITH_ERROR.toString())) {
logger.info("Get task response: " + responseInMap.toString());
fail("The task failed with state " + state);
}
logger.info("The " + i + "-th response: " + responseInMap.toString());
Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND);
}
fail("The task failed to complete after " + MAX_TASK_RESULT_QUERY_TIME_IN_SECOND + " seconds.");
fail("The response failed to meet condition after " + MAX_TASK_RESULT_QUERY_TIME_IN_SECOND + " seconds.");
return null;
}

/**
 * Polls GET /_plugins/_ml/tasks/{taskId} until the task state is COMPLETED.
 *
 * If the task reports FAILED, CANCELLED or COMPLETED_WITH_ERROR the test is failed immediately
 * instead of polling until the timeout, since the task is not expected to leave those states.
 *
 * @param taskId id of the ML task to wait for
 * @return the parsed task response observed when the state was COMPLETED
 */
@SneakyThrows
protected Map<String, Object> waitTaskComplete(String taskId) {
    Predicate<Map<String, Object>> condition = responseInMap -> {
        String state = responseInMap.get(MLTask.STATE_FIELD).toString();
        // Fail fast on failure states; otherwise the generic poller would keep retrying
        // until its attempt limit and report only a generic timeout message.
        if (state.equals(MLTaskState.FAILED.toString())
            || state.equals(MLTaskState.CANCELLED.toString())
            || state.equals(MLTaskState.COMPLETED_WITH_ERROR.toString())) {
            logger.info("Get task response: " + responseInMap.toString());
            fail("The task failed with state " + state);
        }
        return state.equals(MLTaskState.COMPLETED.toString());
    };
    return waitResponseMeetingCondition("GET", "/_plugins/_ml/tasks/" + taskId, (String) null, condition);
}

// Register the model, then deploy it. Returns the model_id only once the model has been deployed.
protected String registerModelThenDeploy(String requestBody) {
String registerModelTaskId = registerModel(requestBody);
Expand All @@ -155,12 +165,23 @@ protected String registerModelThenDeploy(String requestBody) {
return modelId;
}

/**
 * Polls GET /_plugins/_ml/models/{modelId} until the reported model state is none of
 * DEPLOYED, DEPLOYING or PARTIALLY_DEPLOYED.
 *
 * @param modelId id of the model to wait for
 */
@SneakyThrows
private void waitModelUndeployed(String modelId) {
    Predicate<Map<String, Object>> noLongerDeployed = modelResponse -> {
        String modelState = modelResponse.get(MLModel.MODEL_STATE_FIELD).toString();
        boolean deployedInSomeForm = modelState.equals(MLModelState.DEPLOYED.toString())
            || modelState.equals(MLModelState.DEPLOYING.toString())
            || modelState.equals(MLModelState.PARTIALLY_DEPLOYED.toString());
        return !deployedInSomeForm;
    };
    String modelEndpoint = "/_plugins/_ml/models/" + modelId;
    waitResponseMeetingCondition("GET", modelEndpoint, (String) null, noLongerDeployed);
}

/**
 * Undeploys the given model, waits until it is no longer reported as deployed, then deletes it.
 *
 * @param modelId id of the model to delete
 */
@SneakyThrows
protected void deleteModel(String modelId) {
    // need to undeploy first as model can be in use
    makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_undeploy", null, (String) null, null);
    // The model status is updated asynchronously after the undeploy request (the previous comment
    // attributed this to an ml-commons cron job), so poll the model state rather than sleeping
    // for a fixed interval before deleting.
    waitModelUndeployed(modelId);
    makeRequest(client(), "DELETE", "/_plugins/_ml/models/" + modelId, null, (String) null, null);
}

Expand Down

0 comments on commit 0331a81

Please sign in to comment.