Skip to content

Add ClientCredentialsTokenProvider #24

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lib/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,8 @@ dependencies {
implementation 'com.palantir.safe-logging:logger'
implementation 'com.palantir.safe-logging:preconditions'

testImplementation 'org.assertj:assertj-core'
testImplementation 'org.junit.jupiter:junit-jupiter'
testImplementation 'org.mockito:mockito-core'
testImplementation 'org.mockito:mockito-junit-jupiter'
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.palantir.computemodules.functions.serde.DefaultDeserializer;
import com.palantir.computemodules.functions.serde.DefaultSerializer;
import com.palantir.logsafe.SafeArg;
import com.palantir.logsafe.Unsafe;
import com.palantir.logsafe.exceptions.SafeRuntimeException;
import com.palantir.logsafe.logger.SafeLogger;
import com.palantir.logsafe.logger.SafeLoggerFactory;
Expand Down Expand Up @@ -90,6 +91,7 @@ public void onFailure(Throwable throwable) {
}
}

@Unsafe
private Result execute(ComputeModuleJob job) {
if (functions.containsKey(job.queryType())) {
return functions.get(job.queryType()).run(new Context(job.jobId()), job.query());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* (c) Copyright 2025 Palantir Technologies Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.palantir.computemodules.auth;

import com.fasterxml.jackson.annotation.JsonProperty;

public record AuthTokenResponse(
@JsonProperty("access_token") String accessToken,
@JsonProperty("scope") String scope,
@JsonProperty("expires_in") Integer expiresIn,
@JsonProperty("token_type") String tokenType) {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
/*
* (c) Copyright 2025 Palantir Technologies Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.palantir.computemodules.auth;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.palantir.logsafe.SafeArg;
import com.palantir.logsafe.exceptions.SafeIllegalArgumentException;
import com.palantir.logsafe.logger.SafeLogger;
import com.palantir.logsafe.logger.SafeLoggerFactory;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.StringJoiner;
import java.util.function.Supplier;
import javax.annotation.Nullable;

public final class ClientCredentialsTokenProvider implements Supplier<String> {
private static final SafeLogger log = SafeLoggerFactory.get(ClientCredentialsTokenProvider.class);
private static final Duration refreshInterval = Duration.ofHours(1);
private static final String oauthTokenPath = "/multipass/api/oauth2/token";
private static final ObjectMapper mapper = new ObjectMapper();

private final HttpClient client;
private Instant lastRefreshed;
private Optional<AuthTokenResponse> maybeTokenResponse;

private final String hostname;
private final String clientId;
private final String clientSecret;
private final List<String> scopes;

private ClientCredentialsTokenProvider(
String hostname, String clientId, String clientSecret, List<String> scopes, HttpClient client) {
this.hostname = extractHost(hostname);
this.clientId = clientId;
this.clientSecret = clientSecret;
this.scopes = scopes;
this.client = client;
lastRefreshed = Instant.now();
maybeTokenResponse = Optional.empty();
}

public static ClientCredentialsTokenProviderBuilder builder() {
return new ClientCredentialsTokenProviderBuilder();
}

@Override
public synchronized String get() {
if (shouldRefreshToken()) {
refreshToken();
}
return maybeTokenResponse.map(AuthTokenResponse::accessToken).orElseThrow();
}

public static Optional<String> getClientId() {
String clientId = System.getenv("CLIENT_ID");
return Optional.ofNullable(clientId);
}

public static Optional<String> getClientSecret() {
String clientSecret = System.getenv("CLIENT_SECRET");
return Optional.ofNullable(clientSecret);
}

private static String extractHost(String host) {
try {
URL url = new URL(host);
return url.getHost();
} catch (MalformedURLException e) {
return host;
}
}

private boolean shouldRefreshToken() {
if (maybeTokenResponse.isEmpty()) {
return true;
}
Instant now = Instant.now();
Duration timeSinceLastRefresh = Duration.between(lastRefreshed, now);
return timeSinceLastRefresh.compareTo(refreshInterval) > 0;
}

private String buildFormParams() {
StringJoiner joiner = new StringJoiner("&");
joiner.add("grant_type=" + URLEncoder.encode("client_credentials", StandardCharsets.UTF_8));
joiner.add("client_id=" + URLEncoder.encode(clientId, StandardCharsets.UTF_8));
joiner.add("client_secret=" + URLEncoder.encode(clientSecret, StandardCharsets.UTF_8));
joiner.add("scope=" + URLEncoder.encode(String.join(" ", scopes), StandardCharsets.UTF_8));

return joiner.toString();
}

private void refreshToken() {
// TODO(sk): might want to retry
try {
lastRefreshed = Instant.now();
HttpRequest request = HttpRequest.newBuilder()
.uri(new URI("https", hostname, oauthTokenPath, null))
.header("Content-Type", "application/x-www-form-urlencoded")
.POST(HttpRequest.BodyPublishers.ofString(buildFormParams()))
.build();
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
if (response.statusCode() == 200) {
maybeTokenResponse = deserialize(response.body());
return;
}
log.warn("Non-200 status code returned from token endpoint", SafeArg.of("response", response));
} catch (URISyntaxException | IOException | InterruptedException e) {
log.error("Exception raised trying to refresh token", e);
}
}

private Optional<AuthTokenResponse> deserialize(String raw) {
try {
AuthTokenResponse response = mapper.readValue(raw, AuthTokenResponse.class);
return Optional.of(response);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}

public static final class ClientCredentialsTokenProviderBuilder {
private @Nullable String hostname = null;
private @Nullable String clientId;
private @Nullable String clientSecret;
private final List<String> scopes;
private HttpClient client = HttpClient.newBuilder().build();

private ClientCredentialsTokenProviderBuilder() {
clientId = getClientId().orElse(null);
clientSecret = getClientSecret().orElse(null);
scopes = new ArrayList<>();
}

public ClientCredentialsTokenProviderBuilder hostname(String value) {
hostname = value;
return this;
}

public ClientCredentialsTokenProviderBuilder clientId(String value) {
clientId = value;
return this;
}

public ClientCredentialsTokenProviderBuilder clientSecret(String value) {
clientSecret = value;
return this;
}

public ClientCredentialsTokenProviderBuilder scopes(String... values) {
Collections.addAll(scopes, values);
return this;
}

public ClientCredentialsTokenProviderBuilder client(HttpClient value) {
client = value;
return this;
}

public ClientCredentialsTokenProvider build() {
if (hostname == null) {
throw new SafeIllegalArgumentException("hostname must be set");
}
if (clientId == null) {
throw new SafeIllegalArgumentException("clientId must be set");
}
if (clientSecret == null) {
throw new SafeIllegalArgumentException("clientSecret must be set");
}
return new ClientCredentialsTokenProvider(hostname, clientId, clientSecret, scopes, client);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.palantir.computemodules.functions.results.Result;
import com.palantir.computemodules.functions.serde.Deserializer;
import com.palantir.computemodules.functions.serde.Serializer;
import com.palantir.logsafe.Unsafe;
import java.io.InputStream;

public final class FunctionRunner<I, O> {
Expand All @@ -42,6 +43,7 @@ public FunctionRunner(
this.serializer = serializer;
}

@Unsafe
public Result run(Context context, Object input) {
I deserializedInput = deserializer.deserialize(input, inputType);
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,7 @@
*/
package com.palantir.computemodules.functions.results;

import com.palantir.logsafe.Unsafe;

@Unsafe
public sealed interface Result permits Ok, Failed {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* (c) Copyright 2025 Palantir Technologies Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.palantir.computemodules;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.palantir.computemodules.auth.AuthTokenResponse;
import com.palantir.computemodules.auth.ClientCredentialsTokenProvider;
import java.io.IOException;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.NoSuchElementException;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;

@ExtendWith(MockitoExtension.class)
public class ClientCredentialsTokenProviderTest {
private static final String CLIENT_ID = "client_id";
private static final String CLIENT_SECRET = "client_secret";
private static final String hostname = "foundry.stack.com";
private static final String hostnameWithScheme = "https://" + hostname;
private static final String expectedTokenUrl = hostnameWithScheme + "/multipass/api/oauth2/token";
private static final String scope1 = "my-app:view";
private static final String scope2 = "my-app:write";
private static final AuthTokenResponse mockOkTokenResponse =
new AuthTokenResponse("dummy_token", scope1 + " " + scope2, 3600, "Bearer");
private static final ObjectMapper objectMapper = new ObjectMapper();

@Mock
private HttpClient client;

@Mock
private HttpResponse<String> response;

private ClientCredentialsTokenProvider fixture;

@Test
void get_provides_access_token() throws IOException, InterruptedException {
fixture = ClientCredentialsTokenProvider.builder()
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.hostname(hostname)
.scopes(scope1, scope2)
.client(client)
.build();

when(client.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)))
.thenReturn(response);
when(response.statusCode()).thenReturn(200);
when(response.body()).thenReturn(objectMapper.writeValueAsString(mockOkTokenResponse));
String token = fixture.get();
assertNotNull(token);
assertEquals(token, "dummy_token");
ArgumentCaptor<HttpRequest> requestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
Mockito.verify(client).send(requestCaptor.capture(), any(HttpResponse.BodyHandler.class));
HttpRequest capturedRequest = requestCaptor.getValue();
assertEquals(capturedRequest.uri().toString(), expectedTokenUrl);
}

@Test
void provider_works_with_scheme_in_hostname() throws IOException, InterruptedException {
fixture = ClientCredentialsTokenProvider.builder()
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.hostname(hostnameWithScheme)
.scopes(scope1, scope2)
.client(client)
.build();

when(client.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)))
.thenReturn(response);
when(response.statusCode()).thenReturn(200);
when(response.body()).thenReturn(objectMapper.writeValueAsString(mockOkTokenResponse));
String token = fixture.get();
assertNotNull(token);
assertEquals(token, "dummy_token");
ArgumentCaptor<HttpRequest> requestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
Mockito.verify(client).send(requestCaptor.capture(), any(HttpResponse.BodyHandler.class));
HttpRequest capturedRequest = requestCaptor.getValue();
assertEquals(capturedRequest.uri().toString(), expectedTokenUrl);
}

@Test
void throws_when_token_response_non_200() throws IOException, InterruptedException {
fixture = ClientCredentialsTokenProvider.builder()
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.hostname(hostname)
.scopes(scope1, scope2)
.client(client)
.build();
when(client.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)))
.thenReturn(response);
when(response.statusCode()).thenReturn(400);
assertThatThrownBy(() -> fixture.get()).isInstanceOf(NoSuchElementException.class);
}
}
Loading