Skip to content

Commit

Permalink
feat: Implement fallback strategy in tier balancer #588 (#593)
Browse files Browse the repository at this point in the history
  • Loading branch information
astsiapanay authored Nov 29, 2024
1 parent b472238 commit ff1c163
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 170 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
@NoArgsConstructor
public class Upstream {

public static final int ERROR_THRESHOLD = 3;

private String endpoint;
private String key;
@JsonDeserialize(using = JsonToStringDeserializer.class)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,57 +1,87 @@
package com.epam.aidial.core.server.upstream;

import com.epam.aidial.core.config.Deployment;
import com.epam.aidial.core.config.Route;
import com.epam.aidial.core.config.Upstream;
import lombok.Getter;
import lombok.Setter;
import com.epam.aidial.core.storage.http.HttpStatus;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

/**
* Tiered load balancer. Each next() call returns an available upstream from the highest tier (lowest tier value in config).
* If the whole highest tier is unavailable, the balancer starts routing to upstreams from the next (lower) tier, if any.
*/
class TieredBalancer implements LoadBalancer<UpstreamState> {
class TieredBalancer {

@Getter
private final List<Upstream> originalUpstreams;
private final List<WeightedRoundRobinBalancer> tiers;

@Getter
@Setter
private long lastAccessTime;
private final List<UpstreamState> upstreamStates = new ArrayList<>();

/**
* Note. The value is taken from {@link Deployment#getMaxRetryAttempts()} or {@link Route#getMaxRetryAttempts()}
*/
@Getter
private final int originalMaxRetryAttempts;
private final List<Predicate<UpstreamState>> predicates = new ArrayList<>();

public TieredBalancer(String deploymentName, List<Upstream> upstreams, int originalMaxRetryAttempts) {
this.originalUpstreams = upstreams;
public TieredBalancer(String deploymentName, List<Upstream> upstreams) {
this.tiers = buildTiers(deploymentName, upstreams);
this.originalMaxRetryAttempts = originalMaxRetryAttempts;
for (WeightedRoundRobinBalancer tier : tiers) {
upstreamStates.addAll(tier.getUpstreams());
}
predicates.add(state -> state.getStatus().is5xx()
&& state.getSource() == UpstreamState.RetryAfterSource.CORE);
predicates.add(state -> state.getStatus().is5xx()
&& state.getSource() == UpstreamState.RetryAfterSource.UPSTREAM);
predicates.add(state -> state.getStatus() == HttpStatus.TOO_MANY_REQUESTS
&& state.getSource() == UpstreamState.RetryAfterSource.CORE);
predicates.add(state -> state.getStatus() == HttpStatus.TOO_MANY_REQUESTS
&& state.getSource() == UpstreamState.RetryAfterSource.UPSTREAM);
}

@Nullable
@Override
public UpstreamState next() {
synchronized Upstream next(Set<Upstream> usedUpstreams) {
for (WeightedRoundRobinBalancer tier : tiers) {
UpstreamState upstreamState = tier.next();
if (upstreamState != null) {
return upstreamState;
return upstreamState.getUpstream();
}
}
// fallback
for (Predicate<UpstreamState> p : predicates) {
UpstreamState candidate = upstreamStates.stream().filter(p)
.filter(upstreamState -> !usedUpstreams.contains(upstreamState.getUpstream()))
.min(Comparator.comparingLong(UpstreamState::getRetryAfter)).orElse(null);
if (candidate != null) {
usedUpstreams.add(candidate.getUpstream());
return candidate.getUpstream();
}
}

return null;
}

/**
 * Records a failed call against the state tracked for the given upstream.
 *
 * @param upstream upstream that failed; must not be null
 * @param status http status reported for the failed call
 * @param retryAfterSeconds retry-after value, handed to the upstream state unchanged
 * @throws NullPointerException if upstream is null
 * @throws IllegalArgumentException if the upstream is not tracked by this balancer
 */
synchronized void fail(Upstream upstream, HttpStatus status, long retryAfterSeconds) {
    // null check runs first, then the lookup, then the state transition
    findUpstreamState(Objects.requireNonNull(upstream)).fail(status, retryAfterSeconds);
}

/**
 * Records a successful call against the state tracked for the given upstream.
 *
 * @param upstream upstream that succeeded; must not be null
 * @throws NullPointerException if upstream is null
 * @throws IllegalArgumentException if the upstream is not tracked by this balancer
 */
synchronized void succeed(Upstream upstream) {
    // null check runs first, then the lookup, then the state transition
    findUpstreamState(Objects.requireNonNull(upstream)).succeeded();
}

/**
 * Looks up the state object that belongs to the given upstream.
 *
 * @param upstream upstream to search for
 * @return the first state whose upstream equals the given one
 * @throws IllegalArgumentException if no tracked state matches the upstream
 */
private UpstreamState findUpstreamState(Upstream upstream) {
    return upstreamStates.stream()
            .filter(state -> state.getUpstream().equals(upstream))
            .findFirst()
            .orElseThrow(() -> new IllegalArgumentException("Upstream is not found: " + upstream));
}

private static List<WeightedRoundRobinBalancer> buildTiers(String deploymentName, List<Upstream> upstreams) {
List<WeightedRoundRobinBalancer> balancers = new ArrayList<>();
Map<Integer, List<Upstream>> groups = upstreams.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;

import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
import javax.annotation.Nullable;

/**
Expand All @@ -28,30 +31,30 @@
@Slf4j
public class UpstreamRoute {

private static final long DEFAULT_RETRY_AFTER_SECONDS_VALUE = 30;

private final LoadBalancer<UpstreamState> balancer;
private final TieredBalancer balancer;
/**
* The maximum number of attempts the route may retry
*/
private final int maxRetryAttempts;

/**
* Current upstream state
* Current upstream
*/
@Nullable
private UpstreamState upstreamState;
private Upstream upstream;
/**
* Attempt counter
*/
@Getter
private int attemptCount;

public UpstreamRoute(LoadBalancer<UpstreamState> balancer, int maxRetryAttempts) {
private final Set<Upstream> usedUpstreams = new HashSet<>();

public UpstreamRoute(TieredBalancer balancer, int maxRetryAttempts) {
this.balancer = balancer;
this.maxRetryAttempts = maxRetryAttempts;
this.upstreamState = balancer.next();
this.attemptCount = upstreamState == null ? 0 : 1;
this.upstream = balancer.next(usedUpstreams);
this.attemptCount = upstream == null ? 0 : 1;
}

/**
Expand All @@ -60,7 +63,7 @@ public UpstreamRoute(LoadBalancer<UpstreamState> balancer, int maxRetryAttempts)
* @return true if upstream available, false otherwise
*/
public boolean available() {
return upstreamState != null && attemptCount <= maxRetryAttempts;
return upstream != null && attemptCount <= maxRetryAttempts;
}

/**
Expand All @@ -72,53 +75,57 @@ public boolean available() {
public Upstream next() {
// if max attempts reached - do not call balancer
if (attemptCount + 1 > maxRetryAttempts) {
this.upstreamState = null;
this.upstream = null;
return null;
}
attemptCount++;
this.upstreamState = balancer.next();
return upstreamState == null ? null : upstreamState.getUpstream();
this.upstream = balancer.next(usedUpstreams);
return upstream;
}

/**
* @return get current upstream. null if no upstream available
*/
@Nullable
public Upstream get() {
return upstreamState == null ? null : upstreamState.getUpstream();
return upstream;
}

public void fail(HttpStatus status) {
fail(status, -1);
}

public void fail(HttpClientResponse response) {
long retryAfter = retrieveRetryAfterSeconds(response);
HttpStatus status = HttpStatus.fromStatusCode(response.statusCode());
fail(status, retryAfter);
}

/**
* Fail current upstream due to error
*
* @param status - response http status; typically, 5xx or 429
* @param retryAfterSeconds - the amount of seconds after which upstream should be available; if status 5xx this value ignored
* @param retryAfterSeconds - the amount of seconds after which upstream should be available
*/
public void fail(HttpStatus status, long retryAfterSeconds) {
if (upstreamState != null) {
upstreamState.failed(status, retryAfterSeconds);
}
}

public void fail(HttpStatus status) {
fail(status, DEFAULT_RETRY_AFTER_SECONDS_VALUE);
void fail(HttpStatus status, long retryAfterSeconds) {
verifyCurrentUpstream();
balancer.fail(upstream, status, retryAfterSeconds);
}

public void fail(HttpClientResponse response) {
fail(HttpStatus.fromStatusCode(response.statusCode()), calculateRetryAfterSeconds(response));
public void succeed() {
verifyCurrentUpstream();
balancer.succeed(upstream);
}

public void succeed() {
if (upstreamState != null) {
upstreamState.succeeded();
}
private void verifyCurrentUpstream() {
Objects.requireNonNull(upstream, "current upstream is undefined");
}

/**
* @param response http response from upstream
* @return the amount of seconds after which upstream should be available
* @return the amount of seconds after which upstream should be available or -1 if the value is not provided
*/
private static long calculateRetryAfterSeconds(HttpClientResponse response) {
private static long retrieveRetryAfterSeconds(HttpClientResponse response) {
try {
String retryAfterHeaderValue = response.getHeader("Retry-After");
if (retryAfterHeaderValue != null) {
Expand All @@ -127,9 +134,9 @@ private static long calculateRetryAfterSeconds(HttpClientResponse response) {
}
log.debug("Retry-after header not found, status code {}", response.statusCode());
} catch (Exception e) {
log.warn("Failed to parse Retry-After header value, fallback to the default value: " + DEFAULT_RETRY_AFTER_SECONDS_VALUE, e);
log.warn("Failed to parse Retry-After header value, fallback to the default value: " + UpstreamState.DEFAULT_RETRY_AFTER_SECONDS_VALUE, e);
}

return DEFAULT_RETRY_AFTER_SECONDS_VALUE;
return -1;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class UpstreamRouteProvider {
/**
* Cached load balancers
*/
private final ConcurrentHashMap<String, TieredBalancer> balancers = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, BalancerWrapper> balancers = new ConcurrentHashMap<>();

public UpstreamRouteProvider(Vertx vertx) {
vertx.setPeriodic(0, TimeUnit.MINUTES.toMillis(1), event -> evictExpiredBalancers());
Expand All @@ -49,19 +49,19 @@ public UpstreamRoute get(Route route) {
}

private UpstreamRoute get(String key, List<Upstream> upstreams, int maxRetryAttempts) {
TieredBalancer balancer = balancers.compute(key, (k, cur) -> {
TieredBalancer result;
if (cur != null && isUpstreamsTheSame(cur.getOriginalUpstreams(), upstreams)
&& maxRetryAttempts == cur.getOriginalMaxRetryAttempts()) {
BalancerWrapper wrapper = balancers.compute(key, (k, cur) -> {
BalancerWrapper result;
if (cur != null && isUpstreamsTheSame(cur.upstreams, upstreams)
&& maxRetryAttempts == cur.maxRetryAttempts) {
result = cur;
} else {
result = new TieredBalancer(key, upstreams, maxRetryAttempts);
result = new BalancerWrapper(key, maxRetryAttempts, upstreams);
}
result.setLastAccessTime(System.currentTimeMillis());
result.lastAccessTime = System.currentTimeMillis();
return result;
});
int result = Math.min(maxRetryAttempts, upstreams.size());
return new UpstreamRoute(balancer, result);
return new UpstreamRoute(wrapper.balancer, result);
}

private List<Upstream> getUpstreams(Deployment deployment) {
Expand Down Expand Up @@ -98,16 +98,34 @@ private String getKey(Deployment deployment) {
private void evictExpiredBalancers() {
long currentTime = System.currentTimeMillis();
for (String key : balancers.keySet()) {
balancers.compute(key, (k, balancer) -> {
if (balancer != null && currentTime - balancer.getLastAccessTime() > IDLE_PERIOD_IN_MS) {
balancers.compute(key, (k, wrapper) -> {
if (wrapper != null && currentTime - wrapper.lastAccessTime > IDLE_PERIOD_IN_MS) {
return null;
}
return balancer;
return wrapper;
});
}
}

/**
 * Compares two upstream lists as sets, so element order and duplicates are ignored.
 *
 * @param a first upstream list
 * @param b second upstream list
 * @return true if both lists contain the same distinct upstreams
 */
private static boolean isUpstreamsTheSame(List<Upstream> a, List<Upstream> b) {
    HashSet<Upstream> first = new HashSet<>(a);
    HashSet<Upstream> second = new HashSet<>(b);
    return first.equals(second);
}

/**
 * Cache entry that pairs a {@link TieredBalancer} with the configuration it was built from
 * and the time the entry was last requested.
 */
private static class BalancerWrapper {
    /**
     * Upstream list the balancer was created with; compared with the requested list to decide
     * whether the cached entry can be reused.
     */
    final List<Upstream> upstreams;

    /**
     * Note. The value is taken from {@link Deployment#getMaxRetryAttempts()} or {@link Route#getMaxRetryAttempts()}
     */
    final int maxRetryAttempts;

    /**
     * Balancer built for the key and upstream list passed to the constructor.
     */
    final TieredBalancer balancer;

    /**
     * Timestamp in milliseconds of the last lookup; read by the eviction task.
     */
    long lastAccessTime;

    public BalancerWrapper(String key, int maxRetryAttempts, List<Upstream> upstreams) {
        this.upstreams = upstreams;
        this.maxRetryAttempts = maxRetryAttempts;
        this.balancer = new TieredBalancer(key, upstreams);
    }
}
}
Loading

0 comments on commit ff1c163

Please sign in to comment.