Skip to content

Commit

Permalink
code review comments and support for token refresh
Browse files Browse the repository at this point in the history
  • Loading branch information
v1r3n committed Jun 28, 2023
1 parent 44b0d89 commit 6d7a128
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 81 deletions.
141 changes: 71 additions & 70 deletions src/main/java/io/orkes/conductor/client/ApiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,25 @@
*/
package io.orkes.conductor.client;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.squareup.okhttp.*;
import com.squareup.okhttp.internal.http.HttpMethod;
import io.orkes.conductor.client.http.*;
import io.orkes.conductor.client.http.api.TokenResourceApi;
import io.orkes.conductor.client.http.auth.ApiKeyAuth;
import io.orkes.conductor.client.http.auth.Authentication;
import io.orkes.conductor.client.model.GenerateTokenRequest;
import okio.BufferedSink;
import okio.Okio;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.threeten.bp.LocalDate;
import org.threeten.bp.OffsetDateTime;
import org.threeten.bp.format.DateTimeFormatter;

import javax.net.ssl.*;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
Expand All @@ -33,45 +52,23 @@
import java.util.*;
import java.util.Map.Entry;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.net.ssl.*;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.threeten.bp.LocalDate;
import org.threeten.bp.OffsetDateTime;
import org.threeten.bp.format.DateTimeFormatter;

import io.orkes.conductor.client.http.*;
import io.orkes.conductor.client.http.api.TokenResourceApi;
import io.orkes.conductor.client.http.auth.ApiKeyAuth;
import io.orkes.conductor.client.http.auth.Authentication;
import io.orkes.conductor.client.model.GenerateTokenRequest;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.squareup.okhttp.*;
import com.squareup.okhttp.internal.http.HttpMethod;
import okio.BufferedSink;
import okio.Okio;

public class ApiClient {
private static final Logger LOGGER = LoggerFactory.getLogger(ApiClient.class);

private static final String TOKEN_CACHE_KEY = "TOKEN";
private final Cache<String, String> tokenCache;
private Cache<String, String> tokenCache;

private final String basePath;
private final Map<String, String> defaultHeaderMap = new HashMap<String, String>();

private String tempFolderPath;

private Map<String, Authentication> authentications;

private InputStream sslCaCert;
private boolean verifyingSsl;
private KeyManager[] keyManagers;
Expand All @@ -91,51 +88,66 @@ public class ApiClient {

private int executorThreadCount = 0;

private long tokenRefreshInSeconds = 2700; //45 minutes

private ScheduledExecutorService tokenRefreshService;
/*
* Constructor for ApiClient
*/

public ApiClient() {
this("http://localhost:8080/api");
this("http://localhost:8080/api", null, null);
}

public ApiClient(String basePath) {
this.tokenCache = CacheBuilder.newBuilder().expireAfterWrite(30, TimeUnit.MINUTES).build();
if(basePath.endsWith("/")) {
basePath = basePath.substring(0, basePath.length()-1);
}
this.basePath = basePath;
httpClient = new OkHttpClient();
httpClient.setRetryOnConnectionFailure(true);
verifyingSsl = true;
json = new JSON();
authentications = new HashMap<>();
this(basePath, null, null);
}

public ApiClient(String basePath, SecretsManager secretsManager, String keyPath, String secretPath) {
this(basePath);
try {
keyId = secretsManager.getSecret(keyPath);
keySecret = secretsManager.getSecret(secretPath);
getToken();
} catch (Throwable t) {
LOGGER.error(t.getMessage(), t);
}
this(basePath, secretsManager.getSecret(keyPath), secretsManager.getSecret(secretPath));
}

public ApiClient(String basePath, String keyId, String keySecret) {
this(basePath);
if(basePath.endsWith("/")) {
basePath = basePath.substring(0, basePath.length()-1);
}
this.basePath = basePath;

this.keyId = keyId;
this.keySecret = keySecret;
try {
getToken();
} catch (Throwable t) {
LOGGER.error(t.getMessage(), t);
this.httpClient = new OkHttpClient();
this.httpClient.setRetryOnConnectionFailure(true);
this.verifyingSsl = true;
this.json = new JSON();
this.tokenCache = CacheBuilder.newBuilder().expireAfterWrite(tokenRefreshInSeconds, TimeUnit.SECONDS).build();
if(useSecurity()) {
scheduleTokenRefresh();
try {
//This should be in the try catch so if the client is initialized and if the server is down or not reachable
//Client will still initialize without errors
getToken();
} catch (Throwable t) {
LOGGER.error(t.getMessage(), t);
}
}
}

public void setTokenRefreshTime(long duration, TimeUnit timeUnit) {
this.tokenRefreshInSeconds = timeUnit.toSeconds(duration);
Cache<String, String> tokenCacheNew = CacheBuilder.newBuilder().expireAfterWrite(tokenRefreshInSeconds, TimeUnit.SECONDS).build();
synchronized (tokenCache) {
tokenCache = tokenCacheNew;
tokenRefreshService.shutdownNow();
scheduleTokenRefresh();
}
}

public ApiClient(String basePath, String token) {
this(basePath);
private void scheduleTokenRefresh() {
this.tokenRefreshService = Executors.newSingleThreadScheduledExecutor();
long refreshInterval = Math.max(30, tokenRefreshInSeconds - 30);
this.tokenRefreshService.scheduleAtFixedRate(()-> {
refreshToken();
}, refreshInterval,refreshInterval, TimeUnit.SECONDS);
}

public boolean useSecurity() {
Expand Down Expand Up @@ -338,13 +350,7 @@ public ApiClient setLenientOnJson(boolean lenientOnJson) {
* @param apiKey API key
*/
public void setApiKey(String apiKey) {
for (Authentication auth : authentications.values()) {
if (auth instanceof ApiKeyAuth) {
((ApiKeyAuth) auth).setApiKey(apiKey);
return;
}
}
throw new RuntimeException("No API key authentication configured!");
this.keyId = apiKey;
}

/**
Expand Down Expand Up @@ -1096,15 +1102,10 @@ public void processHeaderParams(Map<String, String> headerParams, Request.Builde
* @param headerParams Map of header parameters
*/
public void updateParamsForAuth(String[] authNames, List<Pair> queryParams, Map<String, String> headerParams) {
if(useSecurity() && authentications.isEmpty()) {
LOGGER.debug("No authentication set, will refresh token");
refreshToken();
}
for (String authName : authNames) {
Authentication auth = authentications.get(authName);
if (auth != null) {
auth.applyToParams(queryParams, headerParams);
}
String token = getToken();
if(useSecurity()) {
Authentication auth = getApiKeyHeader(token);
auth.applyToParams(queryParams, headerParams);
}
}

Expand Down Expand Up @@ -1259,19 +1260,19 @@ public String getToken() {
}

private String refreshToken() {
LOGGER.debug("Refreshing token @ {}", new Date());
if (this.keyId == null || this.keySecret == null) {
throw new RuntimeException("KeyId and KeySecret must be set in order to get an authentication token");
}
GenerateTokenRequest generateTokenRequest = new GenerateTokenRequest().keyId(this.keyId).keySecret(this.keySecret);
Map<String, String> response = TokenResourceApi.generateTokenWithHttpInfo(this, generateTokenRequest).getData();
String token = response.get("token");
this.setApiKeyHeader(token);
return token;
}

private synchronized void setApiKeyHeader(String token) {
private ApiKeyAuth getApiKeyHeader(String token) {
ApiKeyAuth apiKeyAuth = new ApiKeyAuth("header", "X-Authorization");
apiKeyAuth.setApiKey(token);
authentications.put("api_key", apiKeyAuth);
return apiKeyAuth;
}
}
25 changes: 16 additions & 9 deletions src/main/java/io/orkes/conductor/client/automator/TaskRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,17 @@ class TaskRunner {
this.domain = Optional.ofNullable(PropertyFactory.getString(taskType, DOMAIN, null))
.orElseGet(() -> Optional.ofNullable(PropertyFactory.getString(ALL_WORKERS, DOMAIN, null))
.orElse(taskToDomain.get(taskType)));
this.errorAt = PropertyFactory.getInteger(taskType, DOMAIN, 100);

int defaultLoggingInterval = 100;
int errorInterval = PropertyFactory.getInteger(taskType, "LOG_INTERVAL", 0);
if(errorInterval == 0) {
errorInterval = PropertyFactory.getInteger(ALL_WORKERS, "LOG_INTERVAL", 0);
}
if(errorInterval == 0) {
errorInterval = defaultLoggingInterval;
}
this.errorAt = errorInterval;
LOGGER.info("Polling errors will be sampled at every {} error (after the first 100 errors) for taskType {}", this.errorAt, taskType);
this.executorService =
(ThreadPoolExecutor)
Executors.newFixedThreadPool(
Expand Down Expand Up @@ -192,23 +201,21 @@ private List<Task> pollTasksForWorker() {
LOGGER.debug("Time taken to poll {} task with a batch size of {} is {} ms", taskType, tasks.size(), stopwatch.elapsed(TimeUnit.MILLISECONDS));

} catch (Throwable e) {
//For the first N (errorAt) errors, just print them as is...
permits.release(pollCount - tasks.size());

//For the first 100 errors, just print them as is...
boolean printError = false;
if(pollingErrorCount < errorAt) {
printError = true;
} else if (pollingErrorCount % errorAt == 0) {
if(pollingErrorCount < 100 || pollingErrorCount % errorAt == 0) {
printError = true;
}
pollingErrorCount++;
if(pollingErrorCount > 1_000_000) {
//Reset after 1 million errors
if(pollingErrorCount > 10_000_000) {
//Reset after 10 million errors
pollingErrorCount = 0;
}
if(printError) {
LOGGER.error("Error polling for taskType: {}, error = {}", taskType, e.getMessage(), e);
}
pollingErrorCount++;
permits.release(pollCount - tasks.size());
}
return tasks;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ public ApiClient getApiClient() {
return apiClient;
}

/**
*
* @return ObjectMapper used to serialize objects - can be modified to add additional modules.
*/
public ObjectMapper getObjectMapper() {
return objectMapper;
}
Expand Down
2 changes: 0 additions & 2 deletions src/test/java/io/orkes/conductor/client/LocalWorkerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ public class LocalWorkerTest {

public static void main(String[] args) {
ApiClient apiClient = new ApiClient("http://localhost:8080/api");
//apiClient.setUseGRPC("localhost", 8090);

OrkesClients clients = new OrkesClients(apiClient);
TaskClient taskClient = clients.getTaskClient();

Expand Down

0 comments on commit 6d7a128

Please sign in to comment.