Skip to content

Commit

Permalink
Add Basic HTTP Auth to OAuthHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
Awk34 committed Jan 7, 2025
1 parent 3d1422a commit c7c47f0
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,18 @@ public class OAuthProvider {
private final String tokenRefreshURL;
@Nullable
private final OAuthClientCredentials clientCreds;
private final CredentialEncodingStrategy strategy;

public OAuthProvider(String name,
String loginURL,
String tokenRefreshURL,
@Nullable OAuthClientCredentials clientCreds) {
@Nullable OAuthClientCredentials clientCreds,
@Nullable CredentialEncodingStrategy strategy) {
this.name = name;
this.loginURL = loginURL;
this.tokenRefreshURL = tokenRefreshURL;
this.clientCreds = clientCreds;
this.strategy = strategy;
}

public String getName() {
Expand All @@ -54,6 +57,17 @@ public OAuthClientCredentials getClientCredentials() {
return clientCreds;
}

public CredentialEncodingStrategy getCredentialEncodingStrategy() {
return strategy;
}

public enum CredentialEncodingStrategy {
// (default) Sends client ID & secret as part of the POST request body
FORM_BODY,
// Sends client ID & secret as part of a HTTP Basic Auth header
BASIC_AUTH,
}

public static Builder newBuilder() {
return new Builder();
}
Expand All @@ -66,6 +80,7 @@ public static class Builder {
private String loginURL;
private String tokenRefreshURL;
private OAuthClientCredentials clientCreds;
private CredentialEncodingStrategy strategy;

public Builder() {}

Expand All @@ -89,11 +104,20 @@ public Builder withClientCredentials(@Nullable OAuthClientCredentials clientCred
return this;
}

public Builder withCredentialEncodingStrategy(@Nullable CredentialEncodingStrategy strategy) {
this.strategy = strategy;
return this;
}

public OAuthProvider build() {
Preconditions.checkNotNull(name, "OAuth provider name missing");
Preconditions.checkNotNull(loginURL, "Login URL missing");
Preconditions.checkNotNull(tokenRefreshURL, "Token refresh URL missing");
return new OAuthProvider(name, loginURL, tokenRefreshURL, clientCreds);
// Default to FORM_BODY strategy
if (strategy == null) {
this.strategy = CredentialEncodingStrategy.FORM_BODY;
}
return new OAuthProvider(name, loginURL, tokenRefreshURL, clientCreds, strategy);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,27 @@ public class PutOAuthProviderRequest {
private final String tokenRefreshURL;
private final String clientId;
private final String clientSecret;
private final OAuthProvider.CredentialEncodingStrategy strategy;

public PutOAuthProviderRequest(String loginURL, String tokenRefreshURL, String clientId, String clientSecret) {
this.loginURL = loginURL;
this.tokenRefreshURL = tokenRefreshURL;
this.clientId = clientId;
this.clientSecret = clientSecret;
this.strategy = OAuthProvider.CredentialEncodingStrategy.FORM_BODY;
}

public PutOAuthProviderRequest(
String loginURL,
String tokenRefreshURL,
String clientId,
String clientSecret,
OAuthProvider.CredentialEncodingStrategy strategy) {
this.loginURL = loginURL;
this.tokenRefreshURL = tokenRefreshURL;
this.clientId = clientId;
this.clientSecret = clientSecret;
this.strategy = strategy;
}

public String getLoginURL() {
Expand All @@ -48,4 +63,8 @@ public String getClientId() {
public String getClientSecret() {
return clientSecret;
}

public OAuthProvider.CredentialEncodingStrategy getCredentialEncodingStrategy() {
return strategy;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.cdap.cdap.datapipeline.oauth.GetAccessTokenResponse;
import io.cdap.cdap.datapipeline.oauth.OAuthClientCredentials;
import io.cdap.cdap.datapipeline.oauth.OAuthProvider;
import io.cdap.cdap.datapipeline.oauth.OAuthProvider.CredentialEncodingStrategy;
import io.cdap.cdap.datapipeline.oauth.OAuthRefreshToken;
import io.cdap.cdap.datapipeline.oauth.OAuthStore;
import io.cdap.cdap.datapipeline.oauth.OAuthStoreException;
Expand All @@ -43,6 +44,7 @@
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Optional;
import javax.ws.rs.DefaultValue;
import javax.ws.rs.GET;
Expand Down Expand Up @@ -115,6 +117,7 @@ public void putOAuthProvider(HttpServiceRequest request, HttpServiceResponder re
PutOAuthProviderRequest putOAuthProviderRequest = GSON.fromJson(
StandardCharsets.UTF_8.decode(request.getContent()).toString(),
PutOAuthProviderRequest.class);
CredentialEncodingStrategy strategy = putOAuthProviderRequest.getCredentialEncodingStrategy();
// Validate URLs
URL loginURL = new URL(putOAuthProviderRequest.getLoginURL());
URL tokenRefreshURL = new URL(putOAuthProviderRequest.getTokenRefreshURL());
Expand All @@ -132,6 +135,7 @@ public void putOAuthProvider(HttpServiceRequest request, HttpServiceResponder re
.withLoginURL(loginURL.toString())
.withTokenRefreshURL(tokenRefreshURL.toString())
.withClientCredentials(clientCredentials)
.withCredentialEncodingStrategy(strategy)
.build();
oauthStore.writeProvider(provider, reuseClientCredentials);
responder.sendStatus(HttpURLConnection.HTTP_OK);
Expand Down Expand Up @@ -310,13 +314,30 @@ private boolean checkCredIsValid(HttpResponse response) throws OAuthServiceExcep
private HttpRequest createGetRefreshTokenRequest(OAuthProvider provider, String code, String redirectURI)
throws OAuthServiceException {
OAuthClientCredentials clientCreds = provider.getClientCredentials();
try {
return HttpRequest.post(new URL(provider.getTokenRefreshURL()))
.withBody(String.format(
CredentialEncodingStrategy strategy = provider.getCredentialEncodingStrategy();

String body;
switch (strategy) {
case BASIC_AUTH:
body = String.format("code=%s&redirect_uri=%s&grant_type=authorization_code", code, redirectURI);
break;
case FORM_BODY: // fall-through
default:
body = String.format(
"code=%s&redirect_uri=%s&client_id=%s&client_secret=%s&grant_type=authorization_code",
code, redirectURI, clientCreds.getClientId(), clientCreds.getClientSecret()))
.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED)
.build();
code, redirectURI, clientCreds.getClientId(), clientCreds.getClientSecret());
}

try {
HttpRequest.Builder requestBuilder = HttpRequest.post(new URL(provider.getTokenRefreshURL()))
.withBody(body)
.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED);

if (strategy == CredentialEncodingStrategy.BASIC_AUTH) {
requestBuilder.addHeader(HttpHeaders.AUTHORIZATION, getBasicAuthHeader(clientCreds));
}

return requestBuilder.build();
} catch (MalformedURLException e) {
throw new OAuthServiceException(HttpURLConnection.HTTP_INTERNAL_ERROR, "Malformed URL", e);
}
Expand All @@ -325,19 +346,40 @@ private HttpRequest createGetRefreshTokenRequest(OAuthProvider provider, String
private HttpRequest createGetAccessTokenRequest(OAuthProvider provider, String refreshToken)
throws OAuthServiceException {
OAuthClientCredentials clientCreds = provider.getClientCredentials();
try {
return HttpRequest.post(new URL(provider.getTokenRefreshURL()))
.withBody(
String.format("grant_type=refresh_token&client_id=%s&client_secret=%s&refresh_token=%s",
CredentialEncodingStrategy strategy = provider.getCredentialEncodingStrategy();

String body;
switch (strategy) {
case BASIC_AUTH:
body = String.format("grant_type=refresh_token&refresh_token=%s", refreshToken);
break;
case FORM_BODY: // fall-through
default:
body = String.format("grant_type=refresh_token&client_id=%s&client_secret=%s&refresh_token=%s",
clientCreds.getClientId(),
clientCreds.getClientSecret(),
refreshToken))
.build();
refreshToken);
}

try {
HttpRequest.Builder requestBuilder = HttpRequest.post(new URL(provider.getTokenRefreshURL()))
.withBody(body);

if (strategy == CredentialEncodingStrategy.BASIC_AUTH) {
requestBuilder.addHeader(HttpHeaders.AUTHORIZATION, getBasicAuthHeader(clientCreds));
}

return requestBuilder.build();
} catch (MalformedURLException e) {
throw new OAuthServiceException(HttpURLConnection.HTTP_INTERNAL_ERROR, "Malformed URL", e);
}
}

private String getBasicAuthHeader(OAuthClientCredentials clientCreds) {
String authInfo = String.format("%s:%s", clientCreds.getClientId(), clientCreds.getClientSecret());
return String.format("Basic %s", Base64.getEncoder().encode(authInfo.getBytes()));
}

private OAuthProvider getProvider(String provider) throws OAuthServiceException {
try {
Optional<OAuthProvider> providerOptional = oauthStore.getProvider(provider);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@

package io.cdap.cdap.datapipeline;

import com.google.common.collect.Multimap;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import io.cdap.cdap.common.http.DefaultHttpRequestConfig;
import io.cdap.cdap.datapipeline.oauth.OAuthProvider;
import io.cdap.cdap.datapipeline.oauth.PutOAuthProviderRequest;
import io.cdap.cdap.datapipeline.oauth.PutOAuthCredentialRequest;
import io.cdap.common.http.HttpMethod;
import io.cdap.common.http.HttpRequest;
import io.cdap.common.http.HttpRequests;
Expand Down Expand Up @@ -104,6 +107,39 @@ public void testCreateProviderWithReuseClientCredentialsFalse() throws IOExcepti
Assert.assertEquals(400, createResponse.getResponseCode());
}

@Test
public void testCreateProviderWithBasicAuth() throws IOException {
// Attempt to create provider
String loginURL = "http://www.example.com/login";
String tokenRefreshURL = "http://www.example.com/token";
String clientId = "clientid";
String clientSecret = "clientsecret";
PutOAuthProviderRequest request = new PutOAuthProviderRequest(
loginURL,
tokenRefreshURL,
clientId,
clientSecret,
OAuthProvider.CredentialEncodingStrategy.BASIC_AUTH);
HttpResponse createResponse = makePutCall("provider/testprovider", request);
Assert.assertEquals(200, createResponse.getResponseCode());

// Grab OAuth login URL to verify write succeeded
HttpResponse getResponse = makeGetCall("provider/testprovider/authurl");
Assert.assertEquals(200, getResponse.getResponseCode());
String authURL = getResponse.getResponseBodyAsString();
Assert.assertEquals("http://www.example.com/login?client_id=clientid&redirect_uri=null", authURL);

PutOAuthCredentialRequest credentialRequest = new PutOAuthCredentialRequest("oneTimeCode", "redirectURI");

// Grab OAuth login URL to verify write succeeded
HttpResponse credentialPutResponse = makePutCall(
"provider/testprovider/credential/credential_id_1234",
credentialRequest);
Assert.assertEquals(200, credentialPutResponse.getResponseCode());
String credentialPutUrl = credentialPutResponse.getResponseBodyAsString();
Assert.assertEquals("http://www.example.com/login?client_id=clientid&redirect_uri=null", credentialPutUrl);
}

@Test
public void testGetAuthURLForMissingClientCredentials() throws IOException {
// Attempt to create provider with missing client credentials and 'reuse_client_credentials'
Expand Down
8 changes: 0 additions & 8 deletions cdap-app-templates/cdap-program-report/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,6 @@
<app.main.class>io.cdap.cdap.report.ReportGenerationApp</app.main.class>
</properties>

<repositories>
<repository>
<id>scala-tools.org</id>
<name>Scala-tools Maven2 Repository</name>
<url>https://scala-tools.org/repo-releases</url>
</repository>
</repositories>

<dependencies>
<dependency>
<groupId>com.google.guava</groupId>
Expand Down

0 comments on commit c7c47f0

Please sign in to comment.