Skip to content

Commit

Permalink
[KYUUBI #6669] Track the client ip for internal kyuubi RESTful requests
Browse files Browse the repository at this point in the history
# 🔍 Description
## Issue References 🔗

As title, track the clientIp for internal kyuubi RESTful requests.
Likes
 - getBatch
 - getBatchLocalLog
 - deleteBatch

## Describe Your Solution 🔧

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

## Types of changes 🔖

- [ ] Bugfix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)

## Test Plan 🧪

#### Behavior Without This Pull Request ⚰️

#### Behavior With This Pull Request 🎉

#### Related Unit Tests

---

# Checklist 📝

- [x] This patch was not authored or co-authored using [Generative Tooling](https://www.apache.org/legal/generative-tooling.html)

**Be nice. Be informative.**

Closes #6669 from turboFei/additional_headers.

Closes #6669

87f144e [Wang, Fei] headers
8dd7aca [Wang, Fei] track the client ip
afc78f2 [Wang, Fei] proxy ip

Authored-by: Wang, Fei <[email protected]>
Signed-off-by: Wang, Fei <[email protected]>
  • Loading branch information
turboFei committed Sep 6, 2024
1 parent d7219fc commit db5ce0c
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,12 @@ public Batch createBatch(BatchRequest request, File resourceFile, List<String> e
}

public Batch getBatchById(String batchId) {
return getBatchById(batchId, Collections.emptyMap());
}

public Batch getBatchById(String batchId, Map<String, String> headers) {
String path = String.format("%s/%s", API_BASE_PATH, batchId);
return this.getClient().get(path, null, Batch.class, client.getAuthHeader());
return this.getClient().get(path, null, Batch.class, client.getAuthHeader(), headers);
}

public GetBatchesResponse listBatches(
Expand Down Expand Up @@ -131,12 +135,17 @@ public GetBatchesResponse listBatches(
}

public OperationLog getBatchLocalLog(String batchId, int from, int size) {
return getBatchLocalLog(batchId, from, size, Collections.emptyMap());
}

public OperationLog getBatchLocalLog(
String batchId, int from, int size, Map<String, String> headers) {
Map<String, Object> params = new HashMap<>();
params.put("from", from);
params.put("size", size);

String path = String.format("%s/%s/localLog", API_BASE_PATH, batchId);
return this.getClient().get(path, params, OperationLog.class, client.getAuthHeader());
return this.getClient().get(path, params, OperationLog.class, client.getAuthHeader(), headers);
}

/**
Expand All @@ -156,8 +165,13 @@ public CloseBatchResponse deleteBatch(String batchId, String hs2ProxyUser) {
}

public CloseBatchResponse deleteBatch(String batchId) {
return deleteBatch(batchId, Collections.emptyMap());
}

public CloseBatchResponse deleteBatch(String batchId, Map<String, String> headers) {
String path = String.format("%s/%s", API_BASE_PATH, batchId);
return this.getClient().delete(path, null, CloseBatchResponse.class, client.getAuthHeader());
return this.getClient()
.delete(path, null, CloseBatchResponse.class, client.getAuthHeader(), headers);
}

private IRestClient getClient() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,83 @@

package org.apache.kyuubi.client;

import java.util.Collections;
import java.util.Map;
import org.apache.kyuubi.client.api.v1.dto.MultiPart;

/** A underlying http client interface for common rest request. */
public interface IRestClient extends AutoCloseable {
<T> T get(String path, Map<String, Object> params, Class<T> type, String authHeader);
<T> T get(
String path,
Map<String, Object> params,
Class<T> type,
String authHeader,
Map<String, String> headers);

String get(String path, Map<String, Object> params, String authHeader);
default <T> T get(String path, Map<String, Object> params, Class<T> type, String authHeader) {
return get(path, params, type, authHeader, Collections.emptyMap());
}

<T> T post(String path, String body, Class<T> type, String authHeader);
String get(
String path, Map<String, Object> params, String authHeader, Map<String, String> headers);

<T> T post(String path, Map<String, MultiPart> multiPartMap, Class<T> type, String authHeader);
default String get(String path, Map<String, Object> params, String authHeader) {
return get(path, params, authHeader, Collections.emptyMap());
}

String post(String path, String body, String authHeader);
<T> T post(
String path, String body, Class<T> type, String authHeader, Map<String, String> headers);

<T> T put(String path, String body, Class<T> type, String authHeader);
default <T> T post(String path, String body, Class<T> type, String authHeader) {
return post(path, body, type, authHeader, Collections.emptyMap());
}

String put(String path, String body, String authHeader);
<T> T post(
String path,
Map<String, MultiPart> multiPartMap,
Class<T> type,
String authHeader,
Map<String, String> headers);

<T> T delete(String path, Map<String, Object> params, Class<T> type, String authHeader);
default <T> T post(
String path, Map<String, MultiPart> multiPartMap, Class<T> type, String authHeader) {
return post(path, multiPartMap, type, authHeader, Collections.emptyMap());
}

String delete(String path, Map<String, Object> params, String authHeader);
String post(String path, String body, String authHeader, Map<String, String> headers);

default String post(String path, String body, String authHeader) {
return post(path, body, authHeader, Collections.emptyMap());
}

<T> T put(
String path, String body, Class<T> type, String authHeader, Map<String, String> headers);

default <T> T put(String path, String body, Class<T> type, String authHeader) {
return put(path, body, type, authHeader, Collections.emptyMap());
}

String put(String path, String body, String authHeader, Map<String, String> headers);

default String put(String path, String body, String authHeader) {
return put(path, body, authHeader, Collections.emptyMap());
}

<T> T delete(
String path,
Map<String, Object> params,
Class<T> type,
String authHeader,
Map<String, String> headers);

default <T> T delete(String path, Map<String, Object> params, Class<T> type, String authHeader) {
return delete(path, params, type, authHeader, Collections.emptyMap());
}

String delete(
String path, Map<String, Object> params, String authHeader, Map<String, String> headers);

default String delete(String path, Map<String, Object> params, String authHeader) {
return delete(path, params, authHeader, Collections.emptyMap());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,34 +70,45 @@ public void close() throws Exception {
}

@Override
public <T> T get(String path, Map<String, Object> params, Class<T> type, String authHeader) {
String responseJson = get(path, params, authHeader);
public <T> T get(
String path,
Map<String, Object> params,
Class<T> type,
String authHeader,
Map<String, String> headers) {
String responseJson = get(path, params, authHeader, headers);
return JsonUtils.fromJson(responseJson, type);
}

@Override
public String get(String path, Map<String, Object> params, String authHeader) {
return doRequest(buildURI(path, params), authHeader, RequestBuilder.get());
public String get(
String path, Map<String, Object> params, String authHeader, Map<String, String> headers) {
return doRequest(buildURI(path, params), authHeader, RequestBuilder.get(), headers);
}

@Override
public <T> T post(String path, String body, Class<T> type, String authHeader) {
String responseJson = post(path, body, authHeader);
public <T> T post(
String path, String body, Class<T> type, String authHeader, Map<String, String> headers) {
String responseJson = post(path, body, authHeader, headers);
return JsonUtils.fromJson(responseJson, type);
}

@Override
public String post(String path, String body, String authHeader) {
public String post(String path, String body, String authHeader, Map<String, String> headers) {
RequestBuilder postRequestBuilder = RequestBuilder.post();
if (body != null) {
postRequestBuilder.setEntity(new StringEntity(body, StandardCharsets.UTF_8));
}
return doRequest(buildURI(path), authHeader, postRequestBuilder);
return doRequest(buildURI(path), authHeader, postRequestBuilder, headers);
}

@Override
public <T> T post(
String path, Map<String, MultiPart> multiPartMap, Class<T> type, String authHeader) {
String path,
Map<String, MultiPart> multiPartMap,
Class<T> type,
String authHeader,
Map<String, String> headers) {
MultipartEntityBuilder entityBuilder =
MultipartEntityBuilder.create().setCharset(StandardCharsets.UTF_8);
multiPartMap.forEach(
Expand All @@ -122,43 +133,52 @@ public <T> T post(
RequestBuilder postRequestBuilder = RequestBuilder.post(buildURI(path));
postRequestBuilder.setHeader(httpEntity.getContentType());
postRequestBuilder.setEntity(httpEntity);
String responseJson = doRequest(buildURI(path), authHeader, postRequestBuilder);
String responseJson = doRequest(buildURI(path), authHeader, postRequestBuilder, headers);
return JsonUtils.fromJson(responseJson, type);
}

@Override
public <T> T put(String path, String body, Class<T> type, String authHeader) {
String responseJson = put(path, body, authHeader);
public <T> T put(
String path, String body, Class<T> type, String authHeader, Map<String, String> headers) {
String responseJson = put(path, body, authHeader, headers);
return JsonUtils.fromJson(responseJson, type);
}

@Override
public String put(String path, String body, String authHeader) {
public String put(String path, String body, String authHeader, Map<String, String> headers) {
RequestBuilder putRequestBuilder = RequestBuilder.put();
if (body != null) {
putRequestBuilder.setEntity(new StringEntity(body, StandardCharsets.UTF_8));
}
return doRequest(buildURI(path), authHeader, putRequestBuilder);
return doRequest(buildURI(path), authHeader, putRequestBuilder, headers);
}

@Override
public <T> T delete(String path, Map<String, Object> params, Class<T> type, String authHeader) {
String responseJson = delete(path, params, authHeader);
public <T> T delete(
String path,
Map<String, Object> params,
Class<T> type,
String authHeader,
Map<String, String> headers) {
String responseJson = delete(path, params, authHeader, headers);
return JsonUtils.fromJson(responseJson, type);
}

@Override
public String delete(String path, Map<String, Object> params, String authHeader) {
return doRequest(buildURI(path, params), authHeader, RequestBuilder.delete());
public String delete(
String path, Map<String, Object> params, String authHeader, Map<String, String> headers) {
return doRequest(buildURI(path, params), authHeader, RequestBuilder.delete(), headers);
}

private String doRequest(URI uri, String authHeader, RequestBuilder requestBuilder) {
private String doRequest(
URI uri, String authHeader, RequestBuilder requestBuilder, Map<String, String> headers) {
String response;
try {
if (requestBuilder.getFirstHeader(HttpHeaders.CONTENT_TYPE) == null) {
requestBuilder.setHeader(
HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType());
}
headers.forEach(requestBuilder::setHeader);
if (StringUtils.isNotBlank(authHeader)) {
requestBuilder.setHeader(HttpHeaders.AUTHORIZATION, authHeader);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ private[v1] class BatchesResource extends ApiRequestContext with Logging {
kyuubiInstance =>
new InternalRestClient(
kyuubiInstance,
fe.getConf.get(FRONTEND_PROXY_HTTP_CLIENT_IP_HEADER),
internalSocketTimeout,
internalConnectTimeout,
internalSecurityEnabled,
Expand Down Expand Up @@ -347,7 +348,7 @@ private[v1] class BatchesResource extends ApiRequestContext with Logging {
} else {
val internalRestClient = getInternalRestClient(metadata.kyuubiInstance)
try {
internalRestClient.getBatch(userName, batchId)
internalRestClient.getBatch(userName, fe.getIpAddress, batchId)
} catch {
case e: KyuubiRestException =>
error(s"Error redirecting get batch[$batchId] to ${metadata.kyuubiInstance}", e)
Expand Down Expand Up @@ -458,7 +459,7 @@ private[v1] class BatchesResource extends ApiRequestContext with Logging {
new OperationLog(dummyLogs, dummyLogs.size)
} else if (fe.connectionUrl != metadata.kyuubiInstance) {
val internalRestClient = getInternalRestClient(metadata.kyuubiInstance)
internalRestClient.getBatchLocalLog(userName, batchId, from, size)
internalRestClient.getBatchLocalLog(userName, fe.getIpAddress, batchId, from, size)
} else if (batchV2Enabled(metadata.requestConf) &&
// in batch v2 impl, the operation state is changed from PENDING to RUNNING
// before being added to SessionManager.
Expand Down Expand Up @@ -520,7 +521,7 @@ private[v1] class BatchesResource extends ApiRequestContext with Logging {
info(s"Redirecting delete batch[$batchId] to ${metadata.kyuubiInstance}")
val internalRestClient = getInternalRestClient(metadata.kyuubiInstance)
try {
internalRestClient.deleteBatch(metadata.username, batchId)
internalRestClient.deleteBatch(metadata.username, fe.getIpAddress, batchId)
} catch {
case e: KyuubiRestException =>
error(s"Error redirecting delete batch[$batchId] to ${metadata.kyuubiInstance}", e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.kyuubi.server.api.v1

import java.util.Base64

import scala.collection.JavaConverters._

import org.apache.kyuubi.client.{BatchRestApi, KyuubiRestClient}
import org.apache.kyuubi.client.api.v1.dto.{Batch, CloseBatchResponse, OperationLog}
import org.apache.kyuubi.client.auth.AuthHeaderGenerator
Expand All @@ -38,6 +40,7 @@ import org.apache.kyuubi.service.authentication.InternalSecurityAccessor
*/
class InternalRestClient(
kyuubiInstance: String,
proxyClientIpHeader: String,
socketTimeout: Int,
connectTimeout: Int,
securityEnabled: Boolean,
Expand All @@ -51,21 +54,30 @@ class InternalRestClient(

private val internalBatchRestApi = new BatchRestApi(initKyuubiRestClient())

def getBatch(user: String, batchId: String): Batch = {
def getBatch(user: String, clientIp: String, batchId: String): Batch = {
withAuthUser(user) {
internalBatchRestApi.getBatchById(batchId)
internalBatchRestApi.getBatchById(batchId, Map(proxyClientIpHeader -> clientIp).asJava)
}
}

def getBatchLocalLog(user: String, batchId: String, from: Int, size: Int): OperationLog = {
def getBatchLocalLog(
user: String,
clientIp: String,
batchId: String,
from: Int,
size: Int): OperationLog = {
withAuthUser(user) {
internalBatchRestApi.getBatchLocalLog(batchId, from, size)
internalBatchRestApi.getBatchLocalLog(
batchId,
from,
size,
Map(proxyClientIpHeader -> clientIp).asJava)
}
}

def deleteBatch(user: String, batchId: String): CloseBatchResponse = {
def deleteBatch(user: String, clientIp: String, batchId: String): CloseBatchResponse = {
withAuthUser(user) {
internalBatchRestApi.deleteBatch(batchId)
internalBatchRestApi.deleteBatch(batchId, Map(proxyClientIpHeader -> clientIp).asJava)
}
}

Expand Down

0 comments on commit db5ce0c

Please sign in to comment.