diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/DefaultTaskCancellation.java b/server/src/main/java/org/opensearch/wlm/cancellation/DefaultTaskCancellation.java index e74849c084f48..d502ce0394c63 100644 --- a/server/src/main/java/org/opensearch/wlm/cancellation/DefaultTaskCancellation.java +++ b/server/src/main/java/org/opensearch/wlm/cancellation/DefaultTaskCancellation.java @@ -12,6 +12,8 @@ import org.opensearch.monitor.jvm.JvmStats; import org.opensearch.monitor.process.ProcessProbe; import org.opensearch.search.ResourceType; +import org.opensearch.tasks.CancellableTask; +import org.opensearch.tasks.Task; import org.opensearch.tasks.TaskCancellation; import org.opensearch.wlm.QueryGroupLevelResourceUsageView; @@ -173,21 +175,52 @@ private boolean shouldCancelTasks(QueryGroup queryGroup, ResourceType resourceTy } private List getTaskCancellations(QueryGroup queryGroup, ResourceType resourceType) { - return defaultTaskSelectionStrategy.selectTasksForCancellation( - queryGroup, - // get the active tasks in the query group + List selectedTasksToCancel = defaultTaskSelectionStrategy.selectTasksForCancellation( queryGroupLevelResourceUsageViews.get(queryGroup.get_id()).getActiveTasks(), getReduceBy(queryGroup, resourceType), resourceType ); + List taskCancellations = new ArrayList<>(); + for(Task task : selectedTasksToCancel) { + String cancellationReason = createCancellationReason(queryGroup, task, resourceType); + taskCancellations.add(createTaskCancellation((CancellableTask) task, cancellationReason)); + } + return taskCancellations; + } + + private String createCancellationReason(QueryGroup querygroup, Task task, ResourceType resourceType) { + Double thresholdInPercent = getThresholdInPercent(querygroup, resourceType); + return "[Workload Management] Cancelling Task ID : " + + task.getId() + + " from QueryGroup ID : " + + querygroup.get_id() + + " breached the resource limit of : " + + thresholdInPercent + + " for resource type : " + + resourceType.getName(); + } + + private Double getThresholdInPercent(QueryGroup querygroup, ResourceType resourceType) { + return ((Double) (querygroup.getResourceLimits().get(resourceType))) * 100; + } + + private TaskCancellation createTaskCancellation(CancellableTask task, String cancellationReason) { + return new TaskCancellation(task, List.of(new TaskCancellation.Reason(cancellationReason, 5)), List.of(this::callbackOnCancel)); } protected List getTaskCancellationsForDeletedQueryGroup(QueryGroup queryGroup) { - return defaultTaskSelectionStrategy.selectTasksFromDeletedQueryGroup( - queryGroup, - // get the active tasks in the query group + List tasks = defaultTaskSelectionStrategy.selectTasksFromDeletedQueryGroup( queryGroupLevelResourceUsageViews.get(queryGroup.get_id()).getActiveTasks() ); + List taskCancellations = new ArrayList<>(); + for(Task task : tasks) { + String cancellationReason = "[Workload Management] Cancelling Task ID : " + + task.getId() + + " from QueryGroup ID : " + + queryGroup.get_id(); + taskCancellations.add(createTaskCancellation((CancellableTask) task, cancellationReason)); + } + return taskCancellations; } private long getReduceBy(QueryGroup queryGroup, ResourceType resourceType) { @@ -229,4 +262,8 @@ private boolean isBreachingThreshold(ResourceType resourceType, Double resourceT // Check if resource usage is breaching the threshold return resourceUsageInMillis > convertThresholdIntoLong(resourceType, resourceThresholdInPercentage); } + + private void callbackOnCancel() { + // TODO Implement callback logic here mostly used for Stats + } } diff --git a/server/src/main/java/org/opensearch/wlm/cancellation/DefaultTaskSelectionStrategy.java b/server/src/main/java/org/opensearch/wlm/cancellation/DefaultTaskSelectionStrategy.java index a4c4234d37582..124873647c2e5 100644 --- a/server/src/main/java/org/opensearch/wlm/cancellation/DefaultTaskSelectionStrategy.java +++ b/server/src/main/java/org/opensearch/wlm/cancellation/DefaultTaskSelectionStrategy.java @@ -8,7 +8,6 @@ package org.opensearch.wlm.cancellation; -import org.opensearch.cluster.metadata.QueryGroup; import org.opensearch.search.ResourceType; import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.Task; @@ -47,8 +46,7 @@ public Comparator sortingCondition() { * @return The list of selected tasks * @throws IllegalArgumentException If the limit is less than zero */ - public List selectTasksForCancellation( - QueryGroup querygroup, + public List selectTasksForCancellation( List tasks, long limit, ResourceType resourceType @@ -62,13 +60,11 @@ public List selectTasksForCancellation( List sortedTasks = tasks.stream().sorted(sortingCondition()).collect(Collectors.toList()); - List selectedTasks = new ArrayList<>(); + List selectedTasks = new ArrayList<>(); long accumulated = 0; - for (Task task : sortedTasks) { if (task instanceof CancellableTask) { - String cancellationReason = createCancellationReason(querygroup, task, resourceType); - selectedTasks.add(createTaskCancellation((CancellableTask) task, cancellationReason)); + selectedTasks.add(task); accumulated += resourceType.getResourceUsage(task); if (accumulated >= limit) { break; @@ -84,46 +80,13 @@ public List selectTasksForCancellation( * {@link CancellableTask}. For each selected task, it creates a cancellation reason and adds * a {@link TaskCancellation} object to the list of selected tasks. * - * @param querygroup The {@link QueryGroup} from which the tasks are being selected. * @param tasks The list of {@link Task} objects to be evaluated for cancellation. * @return A list of {@link TaskCancellation} objects representing the tasks selected for cancellation. */ - public List selectTasksFromDeletedQueryGroup(QueryGroup querygroup, List tasks) { - List selectedTasks = new ArrayList<>(); - - for (Task task : tasks) { - if (task instanceof CancellableTask) { - String cancellationReason = "[Workload Management] Cancelling Task ID : " - + task.getId() - + " from QueryGroup ID : " - + querygroup.get_id(); - selectedTasks.add(createTaskCancellation((CancellableTask) task, cancellationReason)); - } - } - return selectedTasks; - } - - private String createCancellationReason(QueryGroup querygroup, Task task, ResourceType resourceType) { - Double thresholdInPercent = getThresholdInPercent(querygroup, resourceType); - return "[Workload Management] Cancelling Task ID : " - + task.getId() - + " from QueryGroup ID : " - + querygroup.get_id() - + " breached the resource limit of : " - + thresholdInPercent - + " for resource type : " - + resourceType.getName(); - } - - private Double getThresholdInPercent(QueryGroup querygroup, ResourceType resourceType) { - return ((Double) (querygroup.getResourceLimits().get(resourceType))) * 100; - } - - private TaskCancellation createTaskCancellation(CancellableTask task, String cancellationReason) { - return new TaskCancellation(task, List.of(new TaskCancellation.Reason(cancellationReason, 5)), List.of(this::callbackOnCancel)); - } - - private void callbackOnCancel() { - // TODO Implement callback logic here mostly used for Stats + public List selectTasksFromDeletedQueryGroup(List tasks) { + return tasks + .stream() + .filter(task -> task instanceof CancellableTask) + .collect(Collectors.toList()); } } diff --git a/server/src/test/java/org/opensearch/wlm/cancellation/DefaultTaskCancellationTests.java b/server/src/test/java/org/opensearch/wlm/cancellation/DefaultTaskCancellationTests.java index f8e5e83becadc..a455478b27116 100644 --- a/server/src/test/java/org/opensearch/wlm/cancellation/DefaultTaskCancellationTests.java +++ b/server/src/test/java/org/opensearch/wlm/cancellation/DefaultTaskCancellationTests.java @@ -19,6 +19,7 @@ import org.opensearch.wlm.QueryGroupLevelResourceUsageView; import org.junit.Before; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -26,6 +27,7 @@ import java.util.Map; import java.util.Set; import java.util.function.BooleanSupplier; +import java.util.stream.Collectors; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -66,6 +68,35 @@ public void setup() { ); } + public void testGetCancellableTasksFrom_setupAppropriateCancellationReasonAndScore() { + ResourceType resourceType = ResourceType.CPU; + long usage = 100_000_000L; + Double threshold = 0.1; + + QueryGroup queryGroup1 = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + QueryGroupLevelResourceUsageView mockView = createResourceUsageViewMock(resourceType, usage); + queryGroupLevelViews.put(queryGroupId1, mockView); + + List cancellableTasksFrom = taskCancellation.getCancellableTasksFrom(queryGroup1); + assertEquals(2, cancellableTasksFrom.size()); + assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId()); + assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId()); + assertEquals( + "[Workload Management] Cancelling Task ID : " + + cancellableTasksFrom.get(0).getTask().getId() + + " from QueryGroup ID : queryGroup1" + + " breached the resource limit of : 10.0 for resource type : cpu", + cancellableTasksFrom.get(0).getReasonString() + ); + assertEquals(5, cancellableTasksFrom.get(0).getReasons().get(0).getCancellationScore()); + } + public void testGetCancellableTasksFrom_returnsTasksWhenBreachingThreshold() { ResourceType resourceType = ResourceType.CPU; long usage = 100_000_000L; @@ -216,8 +247,7 @@ public void testCancelTasks_cancelsTasksFromDeletedQueryGroups() { ); QueryGroupLevelResourceUsageView mockView1 = createResourceUsageViewMock(resourceType, usage); - QueryGroupLevelResourceUsageView mockView2 = mock(QueryGroupLevelResourceUsageView.class); - when(mockView2.getActiveTasks()).thenReturn(List.of(getRandomSearchTask(1000), getRandomSearchTask(1001))); + QueryGroupLevelResourceUsageView mockView2 = createResourceUsageViewMock(resourceType, usage, List.of(1000, 1001)); queryGroupLevelViews.put(queryGroupId1, mockView1); queryGroupLevelViews.put(queryGroupId2, mockView2); activeQueryGroups.add(activeQueryGroup); @@ -228,7 +258,7 @@ public void testCancelTasks_cancelsTasksFromDeletedQueryGroups() { queryGroupLevelViews, activeQueryGroups, deletedQueryGroups, - () -> false + () -> true ); List cancellableTasksFrom = taskCancellation.getAllCancellableTasks(QueryGroup.ResiliencyMode.ENFORCED); @@ -251,6 +281,62 @@ public void testCancelTasks_cancelsTasksFromDeletedQueryGroups() { assertTrue(cancellableTasksFromDeletedQueryGroups.get(1).getTask().isCancelled()); } + public void testCancelTasks_does_not_cancelTasksFromDeletedQueryGroups_whenNodeNotInDuress() { + ResourceType resourceType = ResourceType.CPU; + long usage = 150_000_000_000L; + Double threshold = 0.01; + + QueryGroup activeQueryGroup = new QueryGroup( + "testQueryGroup", + queryGroupId1, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + QueryGroup deletedQueryGroup = new QueryGroup( + "testQueryGroup", + queryGroupId2, + QueryGroup.ResiliencyMode.ENFORCED, + Map.of(resourceType, threshold), + 1L + ); + + QueryGroupLevelResourceUsageView mockView1 = createResourceUsageViewMock(resourceType, usage); + QueryGroupLevelResourceUsageView mockView2 = createResourceUsageViewMock(resourceType, usage, List.of(1000, 1001)); + queryGroupLevelViews.put(queryGroupId1, mockView1); + queryGroupLevelViews.put(queryGroupId2, mockView2); + activeQueryGroups.add(activeQueryGroup); + deletedQueryGroups.add(deletedQueryGroup); + + TestTaskCancellationImpl taskCancellation = new TestTaskCancellationImpl( + new DefaultTaskSelectionStrategy(), + queryGroupLevelViews, + activeQueryGroups, + deletedQueryGroups, + () -> false + ); + + List cancellableTasksFrom = taskCancellation.getAllCancellableTasks(QueryGroup.ResiliencyMode.ENFORCED); + assertEquals(2, cancellableTasksFrom.size()); + assertEquals(1234, cancellableTasksFrom.get(0).getTask().getId()); + assertEquals(4321, cancellableTasksFrom.get(1).getTask().getId()); + + List cancellableTasksFromDeletedQueryGroups = taskCancellation.getTaskCancellationsForDeletedQueryGroup( + deletedQueryGroup + ); + assertEquals(2, cancellableTasksFromDeletedQueryGroups.size()); + assertEquals(1000, cancellableTasksFromDeletedQueryGroups.get(0).getTask().getId()); + assertEquals(1001, cancellableTasksFromDeletedQueryGroups.get(1).getTask().getId()); + + taskCancellation.cancelTasks(); + + assertTrue(cancellableTasksFrom.get(0).getTask().isCancelled()); + assertTrue(cancellableTasksFrom.get(1).getTask().isCancelled()); + assertFalse(cancellableTasksFromDeletedQueryGroups.get(0).getTask().isCancelled()); + assertFalse(cancellableTasksFromDeletedQueryGroups.get(1).getTask().isCancelled()); + } + public void testCancelTasks_cancelsGivenTasks_WhenNodeInDuress() { ResourceType resourceType = ResourceType.CPU; long usage = 150_000_000_000L; @@ -384,6 +470,21 @@ private QueryGroupLevelResourceUsageView createResourceUsageViewMock(ResourceTyp return mockView; } + private QueryGroupLevelResourceUsageView createResourceUsageViewMock( + ResourceType resourceType, + Long usage, + Collection ids + ) { + QueryGroupLevelResourceUsageView mockView = mock(QueryGroupLevelResourceUsageView.class); + when(mockView.getResourceUsageData()).thenReturn(Collections.singletonMap(resourceType, usage)); + when(mockView.getActiveTasks()).thenReturn( + ids.stream() + .map(this::getRandomSearchTask) + .collect(Collectors.toList()) + ); + return mockView; + } + private Task getRandomSearchTask(long id) { return new SearchTask( id, diff --git a/server/src/test/java/org/opensearch/wlm/cancellation/DefaultTaskSelectionStrategyTests.java b/server/src/test/java/org/opensearch/wlm/cancellation/DefaultTaskSelectionStrategyTests.java index 361f52a3b2e38..9649a5dea0bb7 100644 --- a/server/src/test/java/org/opensearch/wlm/cancellation/DefaultTaskSelectionStrategyTests.java +++ b/server/src/test/java/org/opensearch/wlm/cancellation/DefaultTaskSelectionStrategyTests.java @@ -29,82 +29,37 @@ public class DefaultTaskSelectionStrategyTests extends OpenSearchTestCase { public void testSelectTasksFromDeletedQueryGroup() { DefaultTaskSelectionStrategy testDefaultTaskSelectionStrategy = new DefaultTaskSelectionStrategy(); - long thresholdInLong = 100L; - Double threshold = Double.MIN_VALUE; long reduceBy = Long.MIN_VALUE; - ResourceType resourceType = ResourceType.MEMORY; List tasks = getListOfTasks(thresholdInLong); - - QueryGroup queryGroup = new QueryGroup( - "testQueryGroup", - "queryGroupId1", - QueryGroup.ResiliencyMode.ENFORCED, - Map.of(resourceType, threshold), - 1L - ); - - List selectedTasks = testDefaultTaskSelectionStrategy.selectTasksFromDeletedQueryGroup(queryGroup, tasks); - + List selectedTasks = testDefaultTaskSelectionStrategy.selectTasksFromDeletedQueryGroup(tasks); assertFalse(selectedTasks.isEmpty()); - assertEquals( - "[Workload Management] Cancelling Task ID : " + selectedTasks.get(0).getTask().getId() + " from QueryGroup ID : queryGroupId1", - selectedTasks.get(0).getReasonString() - ); - assertEquals(5, selectedTasks.get(0).getReasons().get(0).getCancellationScore()); assertTrue(tasksUsageMeetsThreshold(selectedTasks, reduceBy)); } public void testSelectTasksToCancelSelectsTasksMeetingThreshold_ifReduceByIsGreaterThanZero() { DefaultTaskSelectionStrategy testDefaultTaskSelectionStrategy = new DefaultTaskSelectionStrategy(); long thresholdInLong = 100L; - Double threshold = 0.1; long reduceBy = 50L; ResourceType resourceType = ResourceType.MEMORY; List tasks = getListOfTasks(thresholdInLong); - - QueryGroup queryGroup = new QueryGroup( - "testQueryGroup", - "queryGroupId1", - QueryGroup.ResiliencyMode.ENFORCED, - Map.of(resourceType, threshold), - 1L - ); - - List selectedTasks = testDefaultTaskSelectionStrategy.selectTasksForCancellation( - queryGroup, + List selectedTasks = testDefaultTaskSelectionStrategy.selectTasksForCancellation( tasks, reduceBy, resourceType ); assertFalse(selectedTasks.isEmpty()); - assertEquals( - "[Workload Management] Cancelling Task ID : " - + selectedTasks.get(0).getTask().getId() - + " from QueryGroup ID : queryGroupId1 breached the resource limit of : 10.0 for resource type : memory", - selectedTasks.get(0).getReasonString() - ); - assertEquals(5, selectedTasks.get(0).getReasons().get(0).getCancellationScore()); assertTrue(tasksUsageMeetsThreshold(selectedTasks, reduceBy)); } public void testSelectTasksToCancelSelectsTasksMeetingThreshold_ifReduceByIsLesserThanZero() { DefaultTaskSelectionStrategy testDefaultTaskSelectionStrategy = new DefaultTaskSelectionStrategy(); long thresholdInLong = 100L; - Double threshold = 0.1; long reduceBy = -50L; ResourceType resourceType = ResourceType.MEMORY; List tasks = getListOfTasks(thresholdInLong); - QueryGroup queryGroup = new QueryGroup( - "testQueryGroup", - "queryGroupId1", - QueryGroup.ResiliencyMode.ENFORCED, - Map.of(resourceType, threshold), - 1L - ); - try { - testDefaultTaskSelectionStrategy.selectTasksForCancellation(queryGroup, tasks, reduceBy, resourceType); + testDefaultTaskSelectionStrategy.selectTasksForCancellation(tasks, reduceBy, resourceType); } catch (Exception e) { assertTrue(e instanceof IllegalArgumentException); assertEquals("limit has to be greater than zero", e.getMessage()); @@ -114,20 +69,10 @@ public void testSelectTasksToCancelSelectsTasksMeetingThreshold_ifReduceByIsLess public void testSelectTasksToCancelSelectsTasksMeetingThreshold_ifReduceByIsEqualToZero() { DefaultTaskSelectionStrategy testDefaultTaskSelectionStrategy = new DefaultTaskSelectionStrategy(); long thresholdInLong = 100L; - Double threshold = 0.1; long reduceBy = 0; ResourceType resourceType = ResourceType.MEMORY; List tasks = getListOfTasks(thresholdInLong); - QueryGroup queryGroup = new QueryGroup( - "testQueryGroup", - "queryGroupId1", - QueryGroup.ResiliencyMode.ENFORCED, - Map.of(resourceType, threshold), - 1L - ); - - List selectedTasks = testDefaultTaskSelectionStrategy.selectTasksForCancellation( - queryGroup, + List selectedTasks = testDefaultTaskSelectionStrategy.selectTasksForCancellation( tasks, reduceBy, resourceType @@ -135,10 +80,10 @@ public void testSelectTasksToCancelSelectsTasksMeetingThreshold_ifReduceByIsEqua assertTrue(selectedTasks.isEmpty()); } - private boolean tasksUsageMeetsThreshold(List selectedTasks, long threshold) { + private boolean tasksUsageMeetsThreshold(List selectedTasks, long threshold) { long memory = 0; - for (TaskCancellation task : selectedTasks) { - memory += task.getTask().getTotalResourceUtilization(ResourceStats.MEMORY); + for (Task task : selectedTasks) { + memory += task.getTotalResourceUtilization(ResourceStats.MEMORY); if (memory > threshold) { return true; }