Skip to content

Commit

Permalink
Fix remote shards balance (#15335)
Browse files Browse the repository at this point in the history
* Fix remote shards balance

Signed-off-by: panguixin <[email protected]>

* add changelog

Signed-off-by: panguixin <[email protected]>

---------

Signed-off-by: panguixin <[email protected]>
Signed-off-by: Andrew Ross <[email protected]>
Co-authored-by: Andrew Ross <[email protected]>
  • Loading branch information
bugmakerrrrrr and andrross authored Dec 13, 2024
1 parent b67cdf4 commit b359dd8
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Ensure consistency of system flag on IndexMetadata after diff is applied ([#16644](https://github.com/opensearch-project/OpenSearch/pull/16644))
- Skip remote-repositories validations for node-joins when RepositoriesService is not in sync with cluster-state ([#16763](https://github.com/opensearch-project/OpenSearch/pull/16763))
- Fix _list/shards API failing when closed indices are present ([#16606](https://github.com/opensearch-project/OpenSearch/pull/16606))
- Fix remote shards balance ([#15335](https://github.com/opensearch-project/OpenSearch/pull/15335))

### Security

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,17 @@ void balance() {
final Map<String, Integer> nodePrimaryShardCount = calculateNodePrimaryShardCount(remoteRoutingNodes);
int totalPrimaryShardCount = nodePrimaryShardCount.values().stream().reduce(0, Integer::sum);

totalPrimaryShardCount += routingNodes.unassigned().getNumPrimaries();
int avgPrimaryPerNode = (totalPrimaryShardCount + routingNodes.size() - 1) / routingNodes.size();
int unassignedRemotePrimaryShardCount = 0;
for (ShardRouting shard : routingNodes.unassigned()) {
if (RoutingPool.REMOTE_CAPABLE.equals(RoutingPool.getShardPool(shard, allocation)) && shard.primary()) {
unassignedRemotePrimaryShardCount++;
}
}
totalPrimaryShardCount += unassignedRemotePrimaryShardCount;
final int avgPrimaryPerNode = (totalPrimaryShardCount + remoteRoutingNodes.size() - 1) / remoteRoutingNodes.size();

ArrayDeque<RoutingNode> sourceNodes = new ArrayDeque<>();
ArrayDeque<RoutingNode> targetNodes = new ArrayDeque<>();
final ArrayDeque<RoutingNode> sourceNodes = new ArrayDeque<>();
final ArrayDeque<RoutingNode> targetNodes = new ArrayDeque<>();
for (RoutingNode node : remoteRoutingNodes) {
if (nodePrimaryShardCount.get(node.nodeId()) > avgPrimaryPerNode) {
sourceNodes.add(node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ public AllocationService createRemoteCapableAllocationService() {
}

public AllocationService createRemoteCapableAllocationService(String excludeNodes) {
Settings settings = Settings.builder().put("cluster.routing.allocation.exclude.node_id", excludeNodes).build();
Settings settings = Settings.builder().put("cluster.routing.allocation.exclude._id", excludeNodes).build();
return new MockAllocationService(
randomAllocationDeciders(settings, EMPTY_CLUSTER_SETTINGS, random()),
new TestGatewayAllocator(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,51 @@ public class RemoteShardsRebalanceShardsTests extends RemoteShardsBalancerBaseTe
* Post rebalance primaries should be balanced across all the nodes.
*/
public void testShardAllocationAndRebalance() {
int localOnlyNodes = 20;
int remoteCapableNodes = 40;
int localIndices = 40;
int remoteIndices = 80;
final int localOnlyNodes = 20;
final int remoteCapableNodes = 40;
final int halfRemoteCapableNodes = remoteCapableNodes / 2;
final int localIndices = 40;
final int remoteIndices = 80;
ClusterState clusterState = createInitialCluster(localOnlyNodes, remoteCapableNodes, localIndices, remoteIndices);
AllocationService service = this.createRemoteCapableAllocationService();
final StringBuilder excludeNodes = new StringBuilder();
for (int i = 0; i < halfRemoteCapableNodes; i++) {
excludeNodes.append(getNodeId(i, true));
if (i != (remoteCapableNodes / 2 - 1)) {
excludeNodes.append(", ");
}
}
AllocationService service = this.createRemoteCapableAllocationService(excludeNodes.toString());
clusterState = allocateShardsAndBalance(clusterState, service);
RoutingNodes routingNodes = clusterState.getRoutingNodes();
RoutingAllocation allocation = getRoutingAllocation(clusterState, routingNodes);

final Map<String, Integer> nodePrimariesCounter = getShardCounterPerNodeForRemoteCapablePool(clusterState, allocation, true);
final Map<String, Integer> nodeReplicaCounter = getShardCounterPerNodeForRemoteCapablePool(clusterState, allocation, false);
Map<String, Integer> nodePrimariesCounter = getShardCounterPerNodeForRemoteCapablePool(clusterState, allocation, true);
Map<String, Integer> nodeReplicaCounter = getShardCounterPerNodeForRemoteCapablePool(clusterState, allocation, false);
int avgPrimariesPerNode = getTotalShardCountAcrossNodes(nodePrimariesCounter) / remoteCapableNodes;

// Primary and replica are balanced post first reroute
// Primary and replica are balanced after first allocating unassigned
for (RoutingNode node : routingNodes) {
if (RoutingPool.REMOTE_CAPABLE.equals(RoutingPool.getNodePool(node))) {
if (Integer.parseInt(node.nodeId().split("-")[4]) < halfRemoteCapableNodes) {
assertEquals(0, (int) nodePrimariesCounter.getOrDefault(node.nodeId(), 0));
} else {
assertEquals(avgPrimariesPerNode * 2, (int) nodePrimariesCounter.get(node.nodeId()));
}
assertTrue(nodeReplicaCounter.getOrDefault(node.nodeId(), 0) >= 0);
}
}

// Remove exclude constraint and rebalance
service = this.createRemoteCapableAllocationService();
clusterState = allocateShardsAndBalance(clusterState, service);
routingNodes = clusterState.getRoutingNodes();
allocation = getRoutingAllocation(clusterState, routingNodes);
nodePrimariesCounter = getShardCounterPerNodeForRemoteCapablePool(clusterState, allocation, true);
nodeReplicaCounter = getShardCounterPerNodeForRemoteCapablePool(clusterState, allocation, false);
for (RoutingNode node : routingNodes) {
if (RoutingPool.REMOTE_CAPABLE.equals(RoutingPool.getNodePool(node))) {
assertInRange(nodePrimariesCounter.get(node.nodeId()), avgPrimariesPerNode, remoteCapableNodes - 1);
assertTrue(nodeReplicaCounter.get(node.nodeId()) >= 0);
assertEquals(avgPrimariesPerNode, (int) nodePrimariesCounter.get(node.nodeId()));
assertTrue(nodeReplicaCounter.getOrDefault(node.nodeId(), 0) >= 0);
}
}
}
Expand Down

0 comments on commit b359dd8

Please sign in to comment.