Skip to content

Commit

Permalink
Add Basic HTTP Auth to OAuthHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
Awk34 committed Nov 14, 2024
1 parent 5292457 commit c6f9059
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,18 @@ public class OAuthProvider {
private final String tokenRefreshURL;
@Nullable
private final OAuthClientCredentials clientCreds;
private final CredentialEncodingStrategy strategy;

public OAuthProvider(String name,
String loginURL,
String tokenRefreshURL,
@Nullable OAuthClientCredentials clientCreds) {
@Nullable OAuthClientCredentials clientCreds,
@Nullable CredentialEncodingStrategy strategy) {
this.name = name;
this.loginURL = loginURL;
this.tokenRefreshURL = tokenRefreshURL;
this.clientCreds = clientCreds;
this.strategy = strategy;
}

public String getName() {
Expand All @@ -54,6 +57,17 @@ public OAuthClientCredentials getClientCredentials() {
return clientCreds;
}

public CredentialEncodingStrategy getCredentialEncodingStrategy() {
return strategy;
}

public enum CredentialEncodingStrategy {
// (default) Sends client ID & secret as part of the POST request body
FORM_BODY,
// Sends client ID & secret as part of a HTTP Basic Auth header
BASIC_AUTH,
}

public static Builder newBuilder() {
return new Builder();
}
Expand All @@ -66,6 +80,7 @@ public static class Builder {
private String loginURL;
private String tokenRefreshURL;
private OAuthClientCredentials clientCreds;
private CredentialEncodingStrategy strategy;

public Builder() {}

Expand All @@ -89,11 +104,20 @@ public Builder withClientCredentials(@Nullable OAuthClientCredentials clientCred
return this;
}

public Builder withCredentialEncodingStrategy(@Nullable CredentialEncodingStrategy strategy) {
this.strategy = strategy;
return this;
}

public OAuthProvider build() {
Preconditions.checkNotNull(name, "OAuth provider name missing");
Preconditions.checkNotNull(loginURL, "Login URL missing");
Preconditions.checkNotNull(tokenRefreshURL, "Token refresh URL missing");
return new OAuthProvider(name, loginURL, tokenRefreshURL, clientCreds);
// Default to FORM_BODY strategy
if (strategy == null) {
this.strategy = CredentialEncodingStrategy.FORM_BODY;
}
return new OAuthProvider(name, loginURL, tokenRefreshURL, clientCreds, strategy);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ public class PutOAuthProviderRequest {
private final String tokenRefreshURL;
private final String clientId;
private final String clientSecret;
private final OAuthProviderRequest.CredentialEncodingStrategy strategy;

public PutOAuthProviderRequest(String loginURL, String tokenRefreshURL, String clientId, String clientSecret) {
public PutOAuthProviderRequest(String loginURL, String tokenRefreshURL, String clientId, String clientSecret, OAuthProviderRequest.CredentialEncodingStrategy strategy) {
this.loginURL = loginURL;
this.tokenRefreshURL = tokenRefreshURL;
this.clientId = clientId;
this.clientSecret = clientSecret;
this.strategy = strategy;
}

public String getLoginURL() {
Expand All @@ -48,4 +50,8 @@ public String getClientId() {
public String getClientSecret() {
return clientSecret;
}

public OAuthProviderRequest.CredentialEncodingStrategy getCredentialEncodingStrategy() {
return strategy;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import io.cdap.cdap.datapipeline.oauth.OAuthStoreException;
import io.cdap.cdap.datapipeline.oauth.PutOAuthCredentialRequest;
import io.cdap.cdap.datapipeline.oauth.PutOAuthProviderRequest;
import io.cdap.cdap.datapipeline.oauth.PutOAuthProviderRequest.CredentialEncodingStrategy;
import io.cdap.cdap.datapipeline.oauth.RefreshTokenResponse;
import io.cdap.common.http.HttpRequest;
import io.cdap.common.http.HttpRequests;
Expand All @@ -43,6 +44,7 @@
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Optional;
import javax.ws.rs.DefaultValue;
import javax.ws.rs.GET;
Expand Down Expand Up @@ -115,6 +117,7 @@ public void putOAuthProvider(HttpServiceRequest request, HttpServiceResponder re
PutOAuthProviderRequest putOAuthProviderRequest = GSON.fromJson(
StandardCharsets.UTF_8.decode(request.getContent()).toString(),
PutOAuthProviderRequest.class);
CredentialEncodingStrategy strategy = putOAuthProviderRequest.getCredentialEncodingStrategy();
// Validate URLs
URL loginURL = new URL(putOAuthProviderRequest.getLoginURL());
URL tokenRefreshURL = new URL(putOAuthProviderRequest.getTokenRefreshURL());
Expand All @@ -132,6 +135,7 @@ public void putOAuthProvider(HttpServiceRequest request, HttpServiceResponder re
.withLoginURL(loginURL.toString())
.withTokenRefreshURL(tokenRefreshURL.toString())
.withClientCredentials(clientCredentials)
.withCredentialEncodingStrategy(strategy)
.build();
oauthStore.writeProvider(provider, reuseClientCredentials);
responder.sendStatus(HttpURLConnection.HTTP_OK);
Expand Down Expand Up @@ -310,13 +314,29 @@ private boolean checkCredIsValid(HttpResponse response) throws OAuthServiceExcep
private HttpRequest createGetRefreshTokenRequest(OAuthProvider provider, String code, String redirectURI)
throws OAuthServiceException {
OAuthClientCredentials clientCreds = provider.getClientCredentials();
try {
return HttpRequest.post(new URL(provider.getTokenRefreshURL()))
.withBody(String.format(
CredentialEncodingStrategy strategy = provider.getCredentialEncodingStrategy();

String body;
switch (strategy) {
case BASIC_AUTH:
body = String.format("code=%s&redirect_uri=%s&grant_type=authorization_code", code, redirectURI);
case FORM_BODY:
default:
body = String.format(
"code=%s&redirect_uri=%s&client_id=%s&client_secret=%s&grant_type=authorization_code",
code, redirectURI, clientCreds.getClientId(), clientCreds.getClientSecret()))
.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED)
.build();
code, redirectURI, clientCreds.getClientId(), clientCreds.getClientSecret());
}

try {
HttpRequest.Builder requestBuilder = HttpRequest.post(new URL(provider.getTokenRefreshURL()))
.withBody(body)
.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED);

if (strategy == CredentialEncodingStrategy.BASIC_AUTH) {
requestBuilder.addHeader(HttpHeaders.AUTHORIZATION, getBasicAuthHeader(clientCreds));
}

return requestBuilder.build();
} catch (MalformedURLException e) {
throw new OAuthServiceException(HttpURLConnection.HTTP_INTERNAL_ERROR, "Malformed URL", e);
}
Expand All @@ -325,19 +345,39 @@ private HttpRequest createGetRefreshTokenRequest(OAuthProvider provider, String
private HttpRequest createGetAccessTokenRequest(OAuthProvider provider, String refreshToken)
throws OAuthServiceException {
OAuthClientCredentials clientCreds = provider.getClientCredentials();
try {
return HttpRequest.post(new URL(provider.getTokenRefreshURL()))
.withBody(
String.format("grant_type=refresh_token&client_id=%s&client_secret=%s&refresh_token=%s",
CredentialEncodingStrategy strategy = provider.getCredentialEncodingStrategy();

String body;
switch (strategy) {
case BASIC_AUTH:
body = String.format("grant_type=refresh_token&refresh_token=%s", refreshToken);
break;
case FORM_BODY: // fall-through
default:
body = String.format("grant_type=refresh_token&client_id=%s&client_secret=%s&refresh_token=%s",
clientCreds.getClientId(),
clientCreds.getClientSecret(),
refreshToken))
.build();
refreshToken);
}

try {
HttpRequest.Builder requestBuilder = HttpRequest.post(new URL(provider.getTokenRefreshURL()))
.withBody(body);

if (strategy == CredentialEncodingStrategy.BASIC_AUTH) {
requestBuilder.addHeader(HttpHeaders.AUTHORIZATION, getBasicAuthHeader(clientCreds));
}

return requestBuilder.build();
} catch (MalformedURLException e) {
throw new OAuthServiceException(HttpURLConnection.HTTP_INTERNAL_ERROR, "Malformed URL", e);
}
}

private String getBasicAuthHeader(OAuthClientCredentials clientCreds) {
return String.format("Basic %s", Base64.getEncoder().encode(String.format("%s:%s", clientCreds.getClientId(), clientCreds.getClientSecret()).getBytes()));
}

private OAuthProvider getProvider(String provider) throws OAuthServiceException {
try {
Optional<OAuthProvider> providerOptional = oauthStore.getProvider(provider);
Expand Down

0 comments on commit c6f9059

Please sign in to comment.