diff --git a/kyuubi-rest-client/src/main/java/org/apache/kyuubi/client/BatchRestApi.java b/kyuubi-rest-client/src/main/java/org/apache/kyuubi/client/BatchRestApi.java index aceba9507ae..b846067e0be 100644 --- a/kyuubi-rest-client/src/main/java/org/apache/kyuubi/client/BatchRestApi.java +++ b/kyuubi-rest-client/src/main/java/org/apache/kyuubi/client/BatchRestApi.java @@ -74,8 +74,12 @@ public Batch createBatch(BatchRequest request, File resourceFile, List e } public Batch getBatchById(String batchId) { + return getBatchById(batchId, Collections.emptyMap()); + } + + public Batch getBatchById(String batchId, Map headers) { String path = String.format("%s/%s", API_BASE_PATH, batchId); - return this.getClient().get(path, null, Batch.class, client.getAuthHeader()); + return this.getClient().get(path, null, Batch.class, client.getAuthHeader(), headers); } public GetBatchesResponse listBatches( @@ -131,12 +135,17 @@ public GetBatchesResponse listBatches( } public OperationLog getBatchLocalLog(String batchId, int from, int size) { + return getBatchLocalLog(batchId, from, size, Collections.emptyMap()); + } + + public OperationLog getBatchLocalLog( + String batchId, int from, int size, Map headers) { Map params = new HashMap<>(); params.put("from", from); params.put("size", size); String path = String.format("%s/%s/localLog", API_BASE_PATH, batchId); - return this.getClient().get(path, params, OperationLog.class, client.getAuthHeader()); + return this.getClient().get(path, params, OperationLog.class, client.getAuthHeader(), headers); } /** @@ -156,8 +165,13 @@ public CloseBatchResponse deleteBatch(String batchId, String hs2ProxyUser) { } public CloseBatchResponse deleteBatch(String batchId) { + return deleteBatch(batchId, Collections.emptyMap()); + } + + public CloseBatchResponse deleteBatch(String batchId, Map headers) { String path = String.format("%s/%s", API_BASE_PATH, batchId); - return this.getClient().delete(path, null, CloseBatchResponse.class, client.getAuthHeader()); + return this.getClient() + .delete(path, null, CloseBatchResponse.class, client.getAuthHeader(), headers); } private IRestClient getClient() { diff --git a/kyuubi-rest-client/src/main/java/org/apache/kyuubi/client/IRestClient.java b/kyuubi-rest-client/src/main/java/org/apache/kyuubi/client/IRestClient.java index 0eaffebd246..43e15c6d54a 100644 --- a/kyuubi-rest-client/src/main/java/org/apache/kyuubi/client/IRestClient.java +++ b/kyuubi-rest-client/src/main/java/org/apache/kyuubi/client/IRestClient.java @@ -17,26 +17,83 @@ package org.apache.kyuubi.client; +import java.util.Collections; import java.util.Map; import org.apache.kyuubi.client.api.v1.dto.MultiPart; /** A underlying http client interface for common rest request. */ public interface IRestClient extends AutoCloseable { - T get(String path, Map params, Class type, String authHeader); + T get( + String path, + Map params, + Class type, + String authHeader, + Map headers); - String get(String path, Map params, String authHeader); + default T get(String path, Map params, Class type, String authHeader) { + return get(path, params, type, authHeader, Collections.emptyMap()); + } - T post(String path, String body, Class type, String authHeader); + String get( + String path, Map params, String authHeader, Map headers); - T post(String path, Map multiPartMap, Class type, String authHeader); + default String get(String path, Map params, String authHeader) { + return get(path, params, authHeader, Collections.emptyMap()); + } - String post(String path, String body, String authHeader); + T post( + String path, String body, Class type, String authHeader, Map headers); - T put(String path, String body, Class type, String authHeader); + default T post(String path, String body, Class type, String authHeader) { + return post(path, body, type, authHeader, Collections.emptyMap()); + } - String put(String path, String body, String authHeader); + T post( + String path, + Map multiPartMap, + Class type, + String authHeader, + Map headers); - T delete(String path, Map params, Class type, String authHeader); + default T post( + String path, Map multiPartMap, Class type, String authHeader) { + return post(path, multiPartMap, type, authHeader, Collections.emptyMap()); + } - String delete(String path, Map params, String authHeader); + String post(String path, String body, String authHeader, Map headers); + + default String post(String path, String body, String authHeader) { + return post(path, body, authHeader, Collections.emptyMap()); + } + + T put( + String path, String body, Class type, String authHeader, Map headers); + + default T put(String path, String body, Class type, String authHeader) { + return put(path, body, type, authHeader, Collections.emptyMap()); + } + + String put(String path, String body, String authHeader, Map headers); + + default String put(String path, String body, String authHeader) { + return put(path, body, authHeader, Collections.emptyMap()); + } + + T delete( + String path, + Map params, + Class type, + String authHeader, + Map headers); + + default T delete(String path, Map params, Class type, String authHeader) { + return delete(path, params, type, authHeader, Collections.emptyMap()); + } + + String delete( + String path, Map params, String authHeader, Map headers); + + default String delete(String path, Map params, String authHeader) { + return delete(path, params, authHeader, Collections.emptyMap()); + } } diff --git a/kyuubi-rest-client/src/main/java/org/apache/kyuubi/client/RestClient.java b/kyuubi-rest-client/src/main/java/org/apache/kyuubi/client/RestClient.java index 86de74cf670..eaea887fae8 100644 --- a/kyuubi-rest-client/src/main/java/org/apache/kyuubi/client/RestClient.java +++ b/kyuubi-rest-client/src/main/java/org/apache/kyuubi/client/RestClient.java @@ -70,34 +70,45 @@ public void close() throws Exception { } @Override - public T get(String path, Map params, Class type, String authHeader) { - String responseJson = get(path, params, authHeader); + public T get( + String path, + Map params, + Class type, + String authHeader, + Map headers) { + String responseJson = get(path, params, authHeader, headers); return JsonUtils.fromJson(responseJson, type); } @Override - public String get(String path, Map params, String authHeader) { - return doRequest(buildURI(path, params), authHeader, RequestBuilder.get()); + public String get( + String path, Map params, String authHeader, Map headers) { + return doRequest(buildURI(path, params), authHeader, RequestBuilder.get(), headers); } @Override - public T post(String path, String body, Class type, String authHeader) { - String responseJson = post(path, body, authHeader); + public T post( + String path, String body, Class type, String authHeader, Map headers) { + String responseJson = post(path, body, authHeader, headers); return JsonUtils.fromJson(responseJson, type); } @Override - public String post(String path, String body, String authHeader) { + public String post(String path, String body, String authHeader, Map headers) { RequestBuilder postRequestBuilder = RequestBuilder.post(); if (body != null) { postRequestBuilder.setEntity(new StringEntity(body, StandardCharsets.UTF_8)); } - return doRequest(buildURI(path), authHeader, postRequestBuilder); + return doRequest(buildURI(path), authHeader, postRequestBuilder, headers); } @Override public T post( - String path, Map multiPartMap, Class type, String authHeader) { + String path, + Map multiPartMap, + Class type, + String authHeader, + Map headers) { MultipartEntityBuilder entityBuilder = MultipartEntityBuilder.create().setCharset(StandardCharsets.UTF_8); multiPartMap.forEach( @@ -122,43 +133,52 @@ public T post( RequestBuilder postRequestBuilder = RequestBuilder.post(buildURI(path)); postRequestBuilder.setHeader(httpEntity.getContentType()); postRequestBuilder.setEntity(httpEntity); - String responseJson = doRequest(buildURI(path), authHeader, postRequestBuilder); + String responseJson = doRequest(buildURI(path), authHeader, postRequestBuilder, headers); return JsonUtils.fromJson(responseJson, type); } @Override - public T put(String path, String body, Class type, String authHeader) { - String responseJson = put(path, body, authHeader); + public T put( + String path, String body, Class type, String authHeader, Map headers) { + String responseJson = put(path, body, authHeader, headers); return JsonUtils.fromJson(responseJson, type); } @Override - public String put(String path, String body, String authHeader) { + public String put(String path, String body, String authHeader, Map headers) { RequestBuilder putRequestBuilder = RequestBuilder.put(); if (body != null) { putRequestBuilder.setEntity(new StringEntity(body, StandardCharsets.UTF_8)); } - return doRequest(buildURI(path), authHeader, putRequestBuilder); + return doRequest(buildURI(path), authHeader, putRequestBuilder, headers); } @Override - public T delete(String path, Map params, Class type, String authHeader) { - String responseJson = delete(path, params, authHeader); + public T delete( + String path, + Map params, + Class type, + String authHeader, + Map headers) { + String responseJson = delete(path, params, authHeader, headers); return JsonUtils.fromJson(responseJson, type); } @Override - public String delete(String path, Map params, String authHeader) { - return doRequest(buildURI(path, params), authHeader, RequestBuilder.delete()); + public String delete( + String path, Map params, String authHeader, Map headers) { + return doRequest(buildURI(path, params), authHeader, RequestBuilder.delete(), headers); } - private String doRequest(URI uri, String authHeader, RequestBuilder requestBuilder) { + private String doRequest( + URI uri, String authHeader, RequestBuilder requestBuilder, Map headers) { String response; try { if (requestBuilder.getFirstHeader(HttpHeaders.CONTENT_TYPE) == null) { requestBuilder.setHeader( HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType()); } + headers.forEach(requestBuilder::setHeader); if (StringUtils.isNotBlank(authHeader)) { requestBuilder.setHeader(HttpHeaders.AUTHORIZATION, authHeader); } diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/api/v1/BatchesResource.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/api/v1/BatchesResource.scala index 6e58742d617..3fd5ddbeaa8 100644 --- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/api/v1/BatchesResource.scala +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/api/v1/BatchesResource.scala @@ -76,6 +76,7 @@ private[v1] class BatchesResource extends ApiRequestContext with Logging { kyuubiInstance => new InternalRestClient( kyuubiInstance, + fe.getConf.get(FRONTEND_PROXY_HTTP_CLIENT_IP_HEADER), internalSocketTimeout, internalConnectTimeout, internalSecurityEnabled, @@ -347,7 +348,7 @@ private[v1] class BatchesResource extends ApiRequestContext with Logging { } else { val internalRestClient = getInternalRestClient(metadata.kyuubiInstance) try { - internalRestClient.getBatch(userName, batchId) + internalRestClient.getBatch(userName, fe.getIpAddress, batchId) } catch { case e: KyuubiRestException => error(s"Error redirecting get batch[$batchId] to ${metadata.kyuubiInstance}", e) @@ -458,7 +459,7 @@ private[v1] class BatchesResource extends ApiRequestContext with Logging { new OperationLog(dummyLogs, dummyLogs.size) } else if (fe.connectionUrl != metadata.kyuubiInstance) { val internalRestClient = getInternalRestClient(metadata.kyuubiInstance) - internalRestClient.getBatchLocalLog(userName, batchId, from, size) + internalRestClient.getBatchLocalLog(userName, fe.getIpAddress, batchId, from, size) } else if (batchV2Enabled(metadata.requestConf) && // in batch v2 impl, the operation state is changed from PENDING to RUNNING // before being added to SessionManager. @@ -520,7 +521,7 @@ private[v1] class BatchesResource extends ApiRequestContext with Logging { info(s"Redirecting delete batch[$batchId] to ${metadata.kyuubiInstance}") val internalRestClient = getInternalRestClient(metadata.kyuubiInstance) try { - internalRestClient.deleteBatch(metadata.username, batchId) + internalRestClient.deleteBatch(metadata.username, fe.getIpAddress, batchId) } catch { case e: KyuubiRestException => error(s"Error redirecting delete batch[$batchId] to ${metadata.kyuubiInstance}", e) diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/api/v1/InternalRestClient.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/api/v1/InternalRestClient.scala index 011e0dc4cb1..e6f97efb5ea 100644 --- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/api/v1/InternalRestClient.scala +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/api/v1/InternalRestClient.scala @@ -19,6 +19,8 @@ package org.apache.kyuubi.server.api.v1 import java.util.Base64 +import scala.collection.JavaConverters._ + import org.apache.kyuubi.client.{BatchRestApi, KyuubiRestClient} import org.apache.kyuubi.client.api.v1.dto.{Batch, CloseBatchResponse, OperationLog} import org.apache.kyuubi.client.auth.AuthHeaderGenerator @@ -38,6 +40,7 @@ import org.apache.kyuubi.service.authentication.InternalSecurityAccessor */ class InternalRestClient( kyuubiInstance: String, + proxyClientIpHeader: String, socketTimeout: Int, connectTimeout: Int, securityEnabled: Boolean, @@ -51,21 +54,30 @@ class InternalRestClient( private val internalBatchRestApi = new BatchRestApi(initKyuubiRestClient()) - def getBatch(user: String, batchId: String): Batch = { + def getBatch(user: String, clientIp: String, batchId: String): Batch = { withAuthUser(user) { - internalBatchRestApi.getBatchById(batchId) + internalBatchRestApi.getBatchById(batchId, Map(proxyClientIpHeader -> clientIp).asJava) } } - def getBatchLocalLog(user: String, batchId: String, from: Int, size: Int): OperationLog = { + def getBatchLocalLog( + user: String, + clientIp: String, + batchId: String, + from: Int, + size: Int): OperationLog = { withAuthUser(user) { - internalBatchRestApi.getBatchLocalLog(batchId, from, size) + internalBatchRestApi.getBatchLocalLog( + batchId, + from, + size, + Map(proxyClientIpHeader -> clientIp).asJava) } } - def deleteBatch(user: String, batchId: String): CloseBatchResponse = { + def deleteBatch(user: String, clientIp: String, batchId: String): CloseBatchResponse = { withAuthUser(user) { - internalBatchRestApi.deleteBatch(batchId) + internalBatchRestApi.deleteBatch(batchId, Map(proxyClientIpHeader -> clientIp).asJava) } }