Skip to content

Commit

Permalink
Allow AWS services to get temporary credentials from service account (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
calvinlu3 authored Jan 10, 2024
1 parent 715db38 commit 8a06eca
Show file tree
Hide file tree
Showing 11 changed files with 237 additions and 50 deletions.
11 changes: 8 additions & 3 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@
<artifactId>javax.annotation-api</artifactId>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk</artifactId>
<version>1.11.163</version>
<groupId>software.amazon.awssdk</groupId>
<artifactId>aws-sdk-java</artifactId>
<version>2.22.7</version>
</dependency>
<dependency>
<groupId>org.json</groupId>
Expand Down Expand Up @@ -402,6 +402,11 @@
<artifactId>commons-text</artifactId>
<version>1.10.0</version>
</dependency>
<dependency>
<groupId>org.jsoup</groupId>
<artifactId>jsoup</artifactId>
<version>1.14.3</version>
</dependency>
</dependencies>

<build>
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/mskcc/cbio/oncokb/config/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ public final class Constants {
public static final String ONCOKB_TM = "OncoKB™";

public static final String TESTING_TOKEN = "faketoken";

public static final String ONCOKB_S3_BUCKET = "oncokb-v2";

private Constants() {
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import org.apache.commons.lang3.StringUtils;
import org.mskcc.cbio.oncokb.domain.enumeration.ProjectProfile;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.mskcc.oncokb.meta.model.application.RedisProperties;
import org.mskcc.oncokb.meta.model.application.AWSProperties;
import org.mskcc.oncokb.meta.model.application.RedisProperties;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -33,7 +33,7 @@ public class ApplicationProperties {
private String tokenUsageCheckWhitelist;
private int publicWebsiteApiThreshold;
private FrontendProperties frontend;
private AWSProperties aws;
private SamlAwsProperties samlAws;
private String githubToken;
private Boolean dbReadOnly;
private SmartsheetProperties smartsheet;
Expand Down Expand Up @@ -158,12 +158,12 @@ private List<String> getList(String listStr) {
return Arrays.stream(listStr.split(",")).map(element -> element.trim()).filter(element -> !StringUtils.isEmpty(element)).collect(Collectors.toList());
}

public AWSProperties getAws() {
return aws;
public SamlAwsProperties getSamlAws() {
return this.samlAws;
}

public void setAws(AWSProperties aws) {
this.aws = aws;
public void setSamlAws(SamlAwsProperties samlAws) {
this.samlAws = samlAws;
}

public String getGithubToken() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package org.mskcc.cbio.oncokb.config.application;

public class SamlAwsProperties {

private String serviceAccountUsername;
private String serviceAccountPassword;
private String principalArn; // The Amazon Resource Name (ARN) of the SAML provider in IAM that describes the IdP.
private String roleArn; // The Amazon Resource Name (ARN) of the role that the caller is assuming.
private String region = "us-east-1";


public String getServiceAccountUsername() {
return this.serviceAccountUsername;
}

public void setServiceAccountUsername(String serviceAccountUsername) {
this.serviceAccountUsername = serviceAccountUsername;
}

public String getServiceAccountPassword() {
return this.serviceAccountPassword;
}

public void setServiceAccountPassword(String serviceAccountPassword) {
this.serviceAccountPassword = serviceAccountPassword;
}

public String getPrincipalArn() {
return this.principalArn;
}

public void setPrincipalArn(String principalArn) {
this.principalArn = principalArn;
}

public String getRoleArn() {
return this.roleArn;
}

public void setRoleArn(String roleArn) {
this.roleArn = roleArn;
}

public String getRegion() {
return this.region;
}

public void setRegion(String region) {
this.region = region;
}

}
57 changes: 36 additions & 21 deletions src/main/java/org/mskcc/cbio/oncokb/service/S3Service.java
Original file line number Diff line number Diff line change
@@ -1,36 +1,42 @@
package org.mskcc.cbio.oncokb.service;

import java.io.File;
import java.nio.file.Paths;
import java.util.Optional;

import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
import com.amazonaws.services.s3.model.S3Object;

import org.mskcc.cbio.oncokb.config.application.ApplicationProperties;
import org.mskcc.cbio.oncokb.config.application.SamlAwsProperties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.sync.ResponseTransformer;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;

@Service
public class S3Service {

private final Logger log = LoggerFactory.getLogger(S3Service.class);

private final SamlService samlService;

private final ApplicationProperties applicationProperties;

private AWSCredentials credentials;
private AmazonS3 s3client;
private S3Client s3Client;

public S3Service(ApplicationProperties applicationProperties){
public S3Service(SamlService samlService, ApplicationProperties applicationProperties) {
this.samlService = samlService;
this.applicationProperties = applicationProperties;

if (applicationProperties.getAws() != null) {
String s3AccessKey = applicationProperties.getAws().getS3AccessKey();
String s3SecretKey = applicationProperties.getAws().getS3SecretKey();
String s3Region = applicationProperties.getAws().getS3Region();
credentials = new BasicAWSCredentials(s3AccessKey, s3SecretKey);
s3client = AmazonS3ClientBuilder.standard()
.withCredentials(new AWSStaticCredentialsProvider(credentials)).withRegion(s3Region).build();
SamlAwsProperties samlAwsProperties = applicationProperties.getSamlAws();
if (samlAwsProperties != null) {
String region = samlAwsProperties.getRegion();
s3Client = S3Client.builder().credentialsProvider(samlService.getCredentialsProvider()).region(Region.of(region)).build();
} else {
log.error("Saml AWS properties not configured");
}
}

Expand All @@ -41,7 +47,12 @@ public S3Service(ApplicationProperties applicationProperties){
* @param file the object
*/
public void saveObject(String bucket, String objectPath, File file){
s3client.putObject(bucket, objectPath, file);
PutObjectRequest putObjectRequest = PutObjectRequest
.builder()
.bucket(bucket)
.key(objectPath)
.build();
s3Client.putObject(putObjectRequest, Paths.get(file.getPath()));
}

/**
Expand All @@ -50,11 +61,15 @@ public void saveObject(String bucket, String objectPath, File file){
* @param objectPath the path of the object
* @return a S3 object
*/
public Optional<S3Object> getObject(String bucket, String objectPath){
public Optional<ResponseInputStream<GetObjectResponse>> getObject(String bucket, String objectPath){
try {
S3Object s3object = s3client.getObject(bucket, objectPath);
ResponseInputStream<GetObjectResponse> s3object = s3Client.getObject(GetObjectRequest.builder()
.bucket(bucket)
.key(objectPath)
.build(), ResponseTransformer.toInputStream());
return Optional.of(s3object);
} catch (Exception e) {
log.error(e.getMessage(), e);
return Optional.empty();
}
}
Expand Down
105 changes: 105 additions & 0 deletions src/main/java/org/mskcc/cbio/oncokb/service/SamlService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package org.mskcc.cbio.oncokb.service;

import java.util.function.Supplier;
import javax.annotation.PostConstruct;
import org.jsoup.Jsoup;
import org.jsoup.nodes.Document;
import org.jsoup.nodes.Element;
import org.mskcc.cbio.oncokb.config.application.ApplicationProperties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;

import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleWithSamlCredentialsProvider;
import software.amazon.awssdk.services.sts.model.AssumeRoleWithSamlRequest;

@Service
public class SamlService {
private final Logger log = LoggerFactory.getLogger(SamlService.class);

private StsAssumeRoleWithSamlCredentialsProvider credentialsProvider;

private final String MSK_IDP_URL ="https://ssofed.mskcc.org/idp/startSSO.ping?PartnerSpId=urn:amazon:webservices";
private final String MSK_USERNAME_FIELD = "pf.username";
private final String MSK_PASSWORD_FIELD = "pf.pass";
private final Integer SESSION_DURATION_IN_SECONDS = 28800; // 8 hours, the maximum allowable

private final ApplicationProperties applicationProperties;

public SamlService(ApplicationProperties applicationProperties) {
this.applicationProperties = applicationProperties;
}

@PostConstruct
private void initSecurityTokenServiceProvider() {
if (applicationProperties.getSamlAws() != null) {
AssumeRoleWithSamlRequest samlRequest = AssumeRoleWithSamlRequest
.builder()
.principalArn(applicationProperties.getSamlAws().getPrincipalArn())
.roleArn(applicationProperties.getSamlAws().getRoleArn())
.durationSeconds(SESSION_DURATION_IN_SECONDS)
.build();

Supplier<AssumeRoleWithSamlRequest> supplier = () ->
samlRequest.toBuilder().samlAssertion(getSamlResponse()).build();

String region = applicationProperties.getSamlAws().getRegion();

StsAssumeRoleWithSamlCredentialsProvider stsProvider = StsAssumeRoleWithSamlCredentialsProvider
.builder()
.stsClient(StsClient.builder().credentialsProvider(AnonymousCredentialsProvider.create()).region(Region.of(region)).build())
.refreshRequest(supplier)
.build();
credentialsProvider = stsProvider;
} else {
log.warn("Saml AWS properties not configured");
}
}

public StsAssumeRoleWithSamlCredentialsProvider getCredentialsProvider() {
return this.credentialsProvider;
}

private String getSamlResponse() throws RuntimeException {

RestTemplate restTemplate = new RestTemplate();

HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);

MultiValueMap<String, String> map = new LinkedMultiValueMap<String, String>();
map.add(MSK_USERNAME_FIELD, applicationProperties.getSamlAws().getServiceAccountUsername());
map.add(MSK_PASSWORD_FIELD, applicationProperties.getSamlAws().getServiceAccountPassword());

HttpEntity<MultiValueMap<String, String>> request = new HttpEntity<MultiValueMap<String, String>>(
map,
headers
);
ResponseEntity<String> response = restTemplate.postForEntity(
MSK_IDP_URL,
request,
String.class
);

Document document = Jsoup.parse(response.getBody());
Element samlResponseElement = document
.select("input[name=SAMLResponse]")
.first();

if (samlResponseElement == null) {
throw new RuntimeException("Could not find SAMLResponse value in SAML assertion response");
}

return samlResponseElement.attr("value");
}
}
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
package org.mskcc.cbio.oncokb.web.rest;

import org.apache.commons.lang3.StringUtils;
import org.mskcc.cbio.oncokb.config.Constants;
import org.mskcc.cbio.oncokb.config.application.ApplicationProperties;
import org.mskcc.cbio.oncokb.domain.*;
import org.mskcc.cbio.oncokb.domain.enumeration.FileExtension;
import org.mskcc.cbio.oncokb.querydomain.UserTokenUsage;
import org.mskcc.cbio.oncokb.querydomain.UserTokenUsageWithInfo;
import org.mskcc.cbio.oncokb.repository.UserDetailsRepository;
import org.mskcc.cbio.oncokb.security.AuthoritiesConstants;
import org.mskcc.cbio.oncokb.security.uuid.TokenProvider;
import org.mskcc.cbio.oncokb.service.*;
import org.mskcc.cbio.oncokb.service.dto.UserDTO;
import org.mskcc.cbio.oncokb.service.mapper.UserMapper;
import org.mskcc.cbio.oncokb.web.rest.vm.ExposedToken;
import org.mskcc.cbio.oncokb.web.rest.vm.usageAnalysis.ResourceModel;
import org.mskcc.cbio.oncokb.web.rest.vm.usageAnalysis.UsageSummary;
import org.mskcc.cbio.oncokb.web.rest.vm.usageAnalysis.UserUsage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
Expand All @@ -29,7 +27,7 @@
import java.util.*;
import java.util.stream.Collectors;

import com.fasterxml.jackson.databind.ObjectMapper;
import software.amazon.awssdk.services.s3.S3Client;

import java.io.*;
import java.util.zip.ZipEntry;
Expand Down Expand Up @@ -172,13 +170,13 @@ public void moveTokenStatsToS3() throws IOException {
SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd");
String dateWrapped = dateFormat.format(dateFormat.parse(tokenUsageDateBefore.minus(1, ChronoUnit.DAYS).toString()));
String datedFile = TOKEN_STATS_STORAGE_FILE_PREFIX + dateWrapped + FileExtension.ZIPPED_FILE.getExtension();
if (s3Service.getObject("oncokb", datedFile).isPresent()) {
if (s3Service.getObject(Constants.ONCOKB_S3_BUCKET, datedFile).isPresent()) {
log.info("Token stats have already been wrapped today. Skipping this request.");
} else {
// Update tokenStats in database
updateTokenUsage(tokenUsageDateBefore);
// Send tokenStats to s3
s3Service.saveObject("oncokb", datedFile, createWrappedFile(tokenUsageDateBefore, dateWrapped + FileExtension.TEXT_FILE.getExtension()));
s3Service.saveObject(Constants.ONCOKB_S3_BUCKET, datedFile, createWrappedFile(tokenUsageDateBefore, dateWrapped + FileExtension.TEXT_FILE.getExtension()));
// Delete old tokenStats
tokenStatsService.clearTokenStats(tokenUsageDateBefore);
}
Expand Down
Loading

0 comments on commit 8a06eca

Please sign in to comment.