Skip to content

Commit

Permalink
Connect to the correct endpoints based on runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
jianglai committed Sep 19, 2024
1 parent c47f821 commit 7358ca6
Show file tree
Hide file tree
Showing 120 changed files with 635 additions and 491 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import com.google.common.flogger.FluentLogger;
import google.registry.model.EppResource;
import google.registry.persistence.VKey;
import google.registry.request.Action.Service;
import google.registry.request.Action;
import javax.inject.Inject;
import org.joda.time.DateTime;
import org.joda.time.Duration;
Expand Down Expand Up @@ -78,7 +78,7 @@ public void enqueueAsyncResave(
logger.atInfo().log("Enqueuing async re-save of %s to run at %s.", entityKey, whenToResave);
cloudTasksUtils.enqueue(
QUEUE_ASYNC_ACTIONS,
cloudTasksUtils.createPostTaskWithDelay(
ResaveEntityAction.PATH, Service.BACKEND, params, etaDuration));
cloudTasksUtils.createTaskWithDelay(
ResaveEntityAction.class, Action.Method.POST, params, etaDuration));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import com.google.common.flogger.FluentLogger;
import google.registry.request.Action;
import google.registry.request.Action.GaeService;
import google.registry.request.Parameter;
import google.registry.request.Response;
import google.registry.request.UrlConnectionService;
Expand All @@ -42,7 +43,7 @@
* --service BACKEND -X POST -u '/_dr/task/executeCannedScript}'}
*/
@Action(
service = Action.Service.BACKEND,
service = GaeService.BACKEND,
path = "/_dr/task/executeCannedScript",
method = {POST, GET},
automaticallyPrintOk = true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import google.registry.model.domain.token.BulkPricingPackage;
import google.registry.model.registrar.Registrar;
import google.registry.request.Action;
import google.registry.request.Action.Service;
import google.registry.request.Action.GaeService;
import google.registry.request.auth.Auth;
import google.registry.ui.server.SendEmailUtils;
import google.registry.util.Clock;
Expand All @@ -39,7 +39,7 @@
* An action that checks all {@link BulkPricingPackage} objects for compliance with their max create
* limit.
*/
@Action(service = Service.BACKEND, path = CheckBulkComplianceAction.PATH, auth = Auth.AUTH_ADMIN)
@Action(service = GaeService.BACKEND, path = CheckBulkComplianceAction.PATH, auth = Auth.AUTH_ADMIN)
public class CheckBulkComplianceAction implements Runnable {

public static final String PATH = "/_dr/task/checkBulkCompliance";
Expand Down
182 changes: 124 additions & 58 deletions core/src/main/java/google/registry/batch/CloudTasksUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
package google.registry.batch;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static google.registry.tools.ServiceConnection.getServer;
import static java.util.concurrent.TimeUnit.SECONDS;

import com.google.api.gax.rpc.ApiException;
Expand All @@ -39,11 +39,15 @@
import com.google.protobuf.util.Timestamps;
import google.registry.config.CredentialModule.ApplicationDefaultCredential;
import google.registry.config.RegistryConfig.Config;
import google.registry.request.Action;
import google.registry.request.Action.Method;
import google.registry.request.Action.Service;
import google.registry.util.Clock;
import google.registry.util.CollectionUtils;
import google.registry.util.GoogleCredentialsBundle;
import google.registry.util.RegistryEnvironment;
import google.registry.util.Retrier;
import java.io.Serial;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
Expand All @@ -58,7 +62,7 @@
/** Utilities for dealing with Cloud Tasks. */
public class CloudTasksUtils implements Serializable {

private static final long serialVersionUID = -7605156291755534069L;
@Serial private static final long serialVersionUID = -7605156291755534069L;
private static final FluentLogger logger = FluentLogger.forEnclosingClass();
private static final Random random = new Random();

Expand Down Expand Up @@ -122,7 +126,7 @@ public ImmutableList<Task> enqueue(String queue, Task... tasks) {
*/
private static String processRequestParameters(
String path,
HttpMethod method,
Method method,
Multimap<String, String> params,
BiConsumer<String, String> putHeadersFunction,
Consumer<ByteString> setBodyFunction) {
Expand All @@ -140,7 +144,7 @@ private static String processRequestParameters(
"%s=%s",
escaper.escape(entry.getKey()), escaper.escape(entry.getValue())))
.collect(toImmutableList()));
if (method.equals(HttpMethod.GET)) {
if (method.equals(Method.GET)) {
return String.format("%s?%s", path, encodedParams);
}
putHeadersFunction.accept(HttpHeaders.CONTENT_TYPE, MediaType.FORM_DATA.toString());
Expand All @@ -155,26 +159,26 @@ private static String processRequestParameters(
* default service account as the principal. That account must have permission to submit tasks to
* Cloud Tasks.
*
* <p>The caller of this method is responsible for passing in the appropriate service based on the
* runtime (GAE/GKE). Use the overload that takes an action class if possible.
*
* @param path the relative URI (staring with a slash and ending without one).
* @param method the HTTP method to be used for the request, only GET and POST are supported.
* @param service the App Engine service to route the request to.
* @param method the HTTP method to be used for the request.
* @param service the GAE/GKE service to route the request to.
* @param params a multimap of URL query parameters. Duplicate keys are saved as is, and it is up
* to the server to process the duplicate keys.
* @return the enqueued task.
* @see <a
* href=ttps://cloud.google.com/appengine/docs/standard/java/taskqueue/push/creating-tasks#target>Specifyinig
* the worker service</a>
*/
private Task createTask(
String path, HttpMethod method, Service service, Multimap<String, String> params) {
protected Task createTask(
String path, Method method, Service service, Multimap<String, String> params) {
checkArgument(
path != null && !path.isEmpty() && path.charAt(0) == '/',
"The path must start with a '/'.");
checkArgument(
method.equals(HttpMethod.GET) || method.equals(HttpMethod.POST),
"HTTP method %s is used. Only GET and POST are allowed.",
method);
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder().setHttpMethod(method);
HttpRequest.Builder requestBuilder =
HttpRequest.newBuilder().setHttpMethod(HttpMethod.valueOf(method.name()));
path =
processRequestParameters(
path, method, params, requestBuilder::putHeaders, requestBuilder::setBody);
Expand All @@ -183,17 +187,52 @@ private Task createTask(
.setServiceAccountEmail(credential.serviceAccount())
.setAudience(oauthClientId);
requestBuilder.setOidcToken(oidcTokenBuilder.build());
String totalPath = String.format("%s%s", getServer(service), path);
String totalPath = String.format("%s%s", service.getServiceUrl(), path);
requestBuilder.setUrl(totalPath);
return Task.newBuilder().setHttpRequest(requestBuilder.build()).build();
}

/**
* Create a {@link Task} to be enqueued.
*
* <p>This uses the standard Cloud Tasks auth format to create and send an OIDC ID token with the
* default service account as the principal. That account must have permission to submit tasks to
* Cloud Tasks.
*
* <p>Prefer this overload over the one where the path and service are explicit defined, as this
* class will automatically determine the service to use based on the action and the runtime.
*
* @param actionClazz the action class to run, must be annotated with {@link Action}.
* @param method the HTTP method to be used for the request.
* @param params a multimap of URL query parameters. Duplicate keys are saved as is, and it is up
* to the server to process the duplicate keys.
* @return the enqueued task.
* @see <a
* href=ttps://cloud.google.com/appengine/docs/standard/java/taskqueue/push/creating-tasks#target>Specifyinig
* the worker service</a>
*/
public Task createTask(
Class<? extends Runnable> actionClazz, Method method, Multimap<String, String> params) {
Action action = actionClazz.getAnnotation(Action.class);
checkArgument(
action != null,
"Action class %s is not annotated with @Action",
actionClazz.getSimpleName());
String path = action.path();
Service service =
RegistryEnvironment.isOnJetty() ? Action.ServiceGetter.get(action) : action.service();
return createTask(path, method, service, params);
}

/**
* Create a {@link Task} to be enqueued with a random delay up to {@code jitterSeconds}.
*
* <p>The caller of this method is responsible for passing in the appropriate service based on the
* runtime (GAE/GKE). Use the overload that takes an action class if possible.
*
* @param path the relative URI (staring with a slash and ending without one).
* @param method the HTTP method to be used for the request, only GET and POST are supported.
* @param service the App Engine service to route the request to.
* @param method the HTTP method to be used for the request.
* @param service the GAE/GKE service to route the request to.
* @param params a multimap of URL query parameters. Duplicate keys are saved as is, and it is up
* to the server to process the duplicate keys.
* @param jitterSeconds the number of seconds that a task is randomly delayed up to.
Expand All @@ -202,9 +241,9 @@ private Task createTask(
* href=ttps://cloud.google.com/appengine/docs/standard/java/taskqueue/push/creating-tasks#target>Specifyinig
* the worker service</a>
*/
private Task createTaskWithJitter(
public Task createTaskWithJitter(
String path,
HttpMethod method,
Method method,
Service service,
Multimap<String, String> params,
Optional<Integer> jitterSeconds) {
Expand All @@ -219,12 +258,44 @@ private Task createTaskWithJitter(
Duration.millis(random.nextInt((int) SECONDS.toMillis(jitterSeconds.get()))));
}

/**
* Create a {@link Task} to be enqueued with a random delay up to {@code jitterSeconds}.
*
* <p>Prefer this overload over the one where the path and service are explicit defined, as this
* class will automatically determine the service to use based on the action and the runtime.
*
* @param actionClazz the action class to run, must be annotated with {@link Action}.
* @param method the HTTP method to be used for the request.
* @param params a multimap of URL query parameters. Duplicate keys are saved as is, and it is up
* to the server to process the duplicate keys.
* @param jitterSeconds the number of seconds that a task is randomly delayed up to.
* @return the enqueued task.
* @see <a
* href=ttps://cloud.google.com/appengine/docs/standard/java/taskqueue/push/creating-tasks#target>Specifyinig
* the worker service</a>
*/
public Task createTaskWithJitter(
Class<? extends Runnable> actionClazz,
Method method,
Multimap<String, String> params,
Optional<Integer> jitterSeconds) {
Action action = getAction(actionClazz);
checkState(
action != null,
"Action class %s is not annotated with @Action",
actionClazz.getSimpleName());
String path = action.path();
Service service =
RegistryEnvironment.isOnJetty() ? Action.ServiceGetter.get(action) : action.service();
return createTaskWithJitter(path, method, service, params, jitterSeconds);
}

/**
* Create a {@link Task} to be enqueued with delay of {@code duration}.
*
* @param path the relative URI (staring with a slash and ending without one).
* @param method the HTTP method to be used for the request, only GET and POST are supported.
* @param service the App Engine service to route the request to.
* @param method the HTTP method to be used for the request.
* @param service the GAE/GKE service to route the request to.
* @param params a multimap of URL query parameters. Duplicate keys are saved as is, and it is up
* to the server to process the duplicate keys.
* @param delay the amount of time that a task needs to delayed for.
Expand All @@ -235,7 +306,7 @@ private Task createTaskWithJitter(
*/
private Task createTaskWithDelay(
String path,
HttpMethod method,
Method method,
Service service,
Multimap<String, String> params,
Duration delay) {
Expand All @@ -248,58 +319,53 @@ private Task createTaskWithDelay(
.build();
}

public Task createPostTask(String path, Service service, Multimap<String, String> params) {
return createTask(path, HttpMethod.POST, service, params);
}

public Task createGetTask(String path, Service service, Multimap<String, String> params) {
return createTask(path, HttpMethod.GET, service, params);
}

/**
* Create a {@link Task} via HTTP.POST that will be randomly delayed up to {@code jitterSeconds}.
*/
public Task createPostTaskWithJitter(
String path,
Service service,
Multimap<String, String> params,
Optional<Integer> jitterSeconds) {
return createTaskWithJitter(path, HttpMethod.POST, service, params, jitterSeconds);
}

/**
* Create a {@link Task} via HTTP.GET that will be randomly delayed up to {@code jitterSeconds}.
* Create a {@link Task} to be enqueued with delay of {@code duration}.
*
* <p>Prefer this overload over the one where the path and service are explicit defined, as this
* class will automatically determine the service to use based on the action and the runtime.
*
* @param actionClazz the action class to run, must be annotated with {@link Action}.
* @param method the HTTP method to be used for the request.
* @param params a multimap of URL query parameters. Duplicate keys are saved as is, and it is up
* to the server to process the duplicate keys.
* @param delay the amount of time that a task needs to delayed for.
* @return the enqueued task.
* @see <a
* href=ttps://cloud.google.com/appengine/docs/standard/java/taskqueue/push/creating-tasks#target>Specifyinig
* the worker service</a>
*/
public Task createGetTaskWithJitter(
String path,
Service service,
public Task createTaskWithDelay(
Class<? extends Runnable> actionClazz,
Method method,
Multimap<String, String> params,
Optional<Integer> jitterSeconds) {
return createTaskWithJitter(path, HttpMethod.GET, service, params, jitterSeconds);
}

/** Create a {@link Task} via HTTP.POST that will be delayed for {@code delay}. */
public Task createPostTaskWithDelay(
String path, Service service, Multimap<String, String> params, Duration delay) {
return createTaskWithDelay(path, HttpMethod.POST, service, params, delay);
Duration delay) {
Action action = getAction(actionClazz);
String path = action.path();
Service service =
RegistryEnvironment.isOnJetty() ? Action.ServiceGetter.get(action) : action.service();
return createTaskWithDelay(path, method, service, params, delay);
}

/** Create a {@link Task} via HTTP.GET that will be delayed for {@code delay}. */
public Task createGetTaskWithDelay(
String path, Service service, Multimap<String, String> params, Duration delay) {
return createTaskWithDelay(path, HttpMethod.GET, service, params, delay);
private static Action getAction(Class<? extends Runnable> actionClazz) {
Action action = actionClazz.getAnnotation(Action.class);
checkState(
action != null,
"Action class %s is not annotated with @Action",
actionClazz.getSimpleName());
return action;
}

public abstract static class SerializableCloudTasksClient implements Serializable {

private static final long serialVersionUID = 7872861868968535498L;
@Serial private static final long serialVersionUID = 7872861868968535498L;

public abstract Task enqueue(String projectId, String locationId, String queueName, Task task);
}

public static class GcpCloudTasksClient extends SerializableCloudTasksClient {

private static final long serialVersionUID = -5959253033129154037L;
@Serial private static final long serialVersionUID = -5959253033129154037L;

// Use a supplier so that we can use try-with-resources with the client, which implements
// Autocloseable.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import google.registry.model.eppoutput.EppOutput;
import google.registry.persistence.transaction.QueryComposer.Comparator;
import google.registry.request.Action;
import google.registry.request.Action.GaeService;
import google.registry.request.Response;
import google.registry.request.auth.Auth;
import google.registry.request.lock.LockHandler;
Expand Down Expand Up @@ -67,7 +68,7 @@
* this action runs, thus alerting us that human action is needed to correctly process the delete.
*/
@Action(
service = Action.Service.BACKEND,
service = GaeService.BACKEND,
path = DeleteExpiredDomainsAction.PATH,
auth = Auth.AUTH_ADMIN)
public class DeleteExpiredDomainsAction implements Runnable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import google.registry.model.reporting.HistoryEntryDao;
import google.registry.persistence.VKey;
import google.registry.request.Action;
import google.registry.request.Action.GaeService;
import google.registry.request.Parameter;
import google.registry.request.auth.Auth;
import google.registry.util.Clock;
Expand All @@ -54,7 +55,7 @@
* production.
*/
@Action(
service = Action.Service.BACKEND,
service = GaeService.BACKEND,
path = "/_dr/task/deleteLoadTestData",
method = POST,
auth = Auth.AUTH_ADMIN)
Expand Down
Loading

0 comments on commit 7358ca6

Please sign in to comment.