Skip to content

Commit

Permalink
clean up job id handling in servlets
Browse files Browse the repository at this point in the history
  • Loading branch information
eschultink committed Jul 16, 2024
1 parent b24fe0b commit 7747143
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import com.google.cloud.datastore.Datastore;
import com.google.cloud.datastore.DatastoreOptions;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import lombok.AllArgsConstructor;

import javax.inject.Inject;
import javax.inject.Singleton;
import java.net.URLEncoder;
import java.util.Optional;

/**
Expand All @@ -28,6 +30,9 @@ public static class Params {
public static final String DATASTORE_DATABASE_ID = "dsDatabaseId";
public static final String DATASTORE_NAMESPACE = "dsNamespace";
public static final String DATASTORE_HOST = "dsHost";

//originally defined per-handler in the gae pipelines project
public static final String ROOT_PIPELINE_ID = "root_pipeline_id";
}

public PipelineBackEnd buildBackendFromRequest(HttpServletRequest request) {
Expand Down Expand Up @@ -56,4 +61,17 @@ public Datastore buildDatastoreFromRequest(HttpServletRequest request) {
Optional<String> getParam(HttpServletRequest request, String name) {
return Optional.ofNullable(request.getParameter(name));
}

// instead of URL-safe encoded keys, should be base64-encode them to avoid this issue???

// deals with fact that HttpServletRequest *decodes* url params, so even if pipeline id was originally url-encoded,
// we need to ensure it remains so
public Optional<String> getJobId(HttpServletRequest request, String paramName) {
return getParam(request, paramName).map(URLEncoder::encode);
}

public String getRootPipelineId(HttpServletRequest request) throws ServletException {
return getJobId(request, Params.ROOT_PIPELINE_ID)
.orElseThrow(() -> new ServletException(Params.ROOT_PIPELINE_ID + " parameter not found."));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,15 @@
public class AbortJobHandler {

public static final String PATH_COMPONENT = "rpc/abort";
private static final String ROOT_PIPELINE_ID = "root_pipeline_id";

final JobRunServiceComponent component;
final RequestUtils requestUtils;


public void doGet(HttpServletRequest req, HttpServletResponse resp)
throws IOException, ServletException {
String rootJobHandle = req.getParameter(ROOT_PIPELINE_ID);
if (null == rootJobHandle) {
throw new ServletException(ROOT_PIPELINE_ID + " parameter not found.");
}
String rootJobHandle = requestUtils.getRootPipelineId(req);

try {
StepExecutionComponent stepExecutionComponent =
component.stepExecutionComponent(new StepExecutionModule(requestUtils.buildBackendFromRequest(req)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,14 @@
public class DeleteJobHandler {

public static final String PATH_COMPONENT = "rpc/delete";
private static final String ROOT_PIPELINE_ID = "root_pipeline_id";

final JobRunServiceComponent component;
final RequestUtils requestUtils;

public void doGet(HttpServletRequest req, HttpServletResponse resp)
throws IOException, ServletException {
String rootJobHandle = req.getParameter(ROOT_PIPELINE_ID);
if (null == rootJobHandle) {
throw new ServletException(ROOT_PIPELINE_ID + " parameter not found.");
}

String rootJobHandle = requestUtils.getRootPipelineId(req);

StepExecutionComponent stepExecutionComponent =
component.stepExecutionComponent(new StepExecutionModule(requestUtils.buildBackendFromRequest(req)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import lombok.AllArgsConstructor;

import java.io.IOException;
import java.net.URLEncoder;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
Expand All @@ -48,12 +49,8 @@ public class JsonTreeHandler {
public void doGet(HttpServletRequest req, HttpServletResponse resp)
throws ServletException {

String rootJobHandle = req.getParameter(ROOT_PIPELINE_ID);
if (null == rootJobHandle) {
throw new ServletException(ROOT_PIPELINE_ID + " parameter not found.");
}
String rootJobHandle = requestUtils.getRootPipelineId(req);
try {

StepExecutionComponent stepExecutionComponent =
component.stepExecutionComponent(new StepExecutionModule(requestUtils.buildBackendFromRequest(req)));
PipelineRunner pipelineRunner = stepExecutionComponent.pipelineRunner();
Expand All @@ -65,10 +62,11 @@ public void doGet(HttpServletRequest req, HttpServletResponse resp)
resp.sendError(HttpServletResponse.SC_NOT_FOUND);
return;
}
String rootJobKey = jobInfo.getRootJobKey().getName();
String rootJobKey = jobInfo.getRootJobKey().toUrlSafe();
if (!rootJobKey.equals(rootJobHandle)) {
//in effect, value passed to servlet for root_pipeline_id is not in fact the id of a root job of a pipeline
resp.addHeader(ROOT_PIPELINE_ID, rootJobKey);
resp.sendError(449, rootJobKey);
resp.sendError(449, "parsed root_pipeline_id (" + rootJobHandle + ") has JobInfo from different root job : "+ rootJobKey);
return;
}
PipelineObjects pipelineObjects = pipelineRunner.queryFullPipeline(rootJobKey);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.google.appengine.tools.mapreduce.impl.util;

import com.google.cloud.datastore.Key;
import org.junit.jupiter.api.Test;

import javax.servlet.http.HttpServletRequest;

import java.net.URLDecoder;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

class RequestUtilsTest {

public final String ENCODED_EXAMPLE = "partition_id+%7B%0A++project_id%3A+%22worklytics-dev%22%0A%7D%0Apath+%7B%0A++kind%3A+%22pipeline-job%22%0A++name%3A+%22c6fa877b-81a6-4e17-a8f7-62268036db97%22%0A%7D%0A";

@Test
void getJobId() {

Key example =
Key.fromUrlSafe(ENCODED_EXAMPLE);

RequestUtils requestUtils = new RequestUtils();
HttpServletRequest request = mock(HttpServletRequest.class);

// HttpServletRequest *decodes* url params
when(request.getParameter("root_pipeline_id"))
.thenReturn(URLDecoder.decode(example.toUrlSafe()));

//in effect, ensure round-trip encode and decode works
assertEquals(ENCODED_EXAMPLE,
requestUtils.getJobId(request, "root_pipeline_id").get());
}
}

0 comments on commit 7747143

Please sign in to comment.