Skip to content

Commit

Permalink
feat(oauth): add a task to retrieve azure access token (#145)
Browse files Browse the repository at this point in the history
created task to retrieve token from credentials

created unit tests (disabled by default)

refactored code to extract azure identity credentials

#131
  • Loading branch information
mgabelle authored Oct 9, 2024
1 parent 958dc10 commit 5b8a385
Show file tree
Hide file tree
Showing 8 changed files with 226 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package io.kestra.plugin.azure.datafactory;
package io.kestra.plugin.azure;

import com.azure.core.credential.TokenCredential;
import com.azure.core.management.AzureEnvironment;
import com.azure.core.management.profile.AzureProfile;
import com.azure.identity.*;
import com.azure.resourcemanager.datafactory.DataFactoryManager;
import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.property.Property;
import io.kestra.core.models.tasks.Task;
Expand All @@ -17,16 +14,20 @@
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;

/**
* This class enables the creation of Azure credentials from different sources.
* For more information please refer to the <a href="https://learn.microsoft.com/en-us/java/api/overview/azure/identity-readme?view=azure-java-stable">Azure Identity documentation</a>
*
*/
@SuperBuilder
@ToString
@EqualsAndHashCode
@Getter
@NoArgsConstructor
public abstract class AbstractDataFactoryConnection extends Task implements AbstractDataFactoryConnectionInterface {
public abstract class AbstractAzureIdentityConnection extends Task implements AzureIdentityConnectionInterface {
@NotNull
protected Property<String> subscriptionId;
@NotNull
protected Property<String> tenantId;
@Builder.Default
protected Property<String> tenantId = Property.of("");

@Builder.Default
protected Property<String> clientId = Property.of("");
Expand All @@ -35,43 +36,30 @@ public abstract class AbstractDataFactoryConnection extends Task implements Abst
@Builder.Default
protected Property<String> pemCertificate = Property.of("");

protected DataFactoryManager getDataFactoryManager(RunContext runContext) throws IllegalVariableEvaluationException {
runContext.logger().info("Authenticating to Azure Data Factory");
return DataFactoryManager.authenticate(credentials(runContext), profile(runContext));
}

private AzureProfile profile(RunContext runContext) throws IllegalVariableEvaluationException {
public TokenCredential credentials(RunContext runContext) throws IllegalVariableEvaluationException {
final String tenantId = this.tenantId.as(runContext, String.class);
final String subscriptionId = this.subscriptionId.as(runContext, String.class);

return new AzureProfile(
tenantId,
subscriptionId,
AzureEnvironment.AZURE
);
}

private TokenCredential credentials(RunContext runContext) throws IllegalVariableEvaluationException {
final String tenantId = runContext.render(this.tenantId.as(runContext, String.class));
final String clientId = runContext.render(this.clientId.as(runContext, String.class));
final String clientId = this.clientId.as(runContext, String.class);

//Create client/secret credentials
final String clientSecret = runContext.render(this.clientSecret.as(runContext, String.class));
final String clientSecret = this.clientSecret.as(runContext, String.class);
if(StringUtils.isNotBlank(clientSecret)) {
runContext.logger().info("Authentication is using Client Secret Credentials");
return getClientSecretCredential(tenantId, clientId, clientSecret);
}

//Create client/certificate credentials
final String pemCertificate = runContext.render(this.pemCertificate.as(runContext, String.class));
final String pemCertificate = this.pemCertificate.as(runContext, String.class);
if(StringUtils.isNotBlank(pemCertificate)) {
runContext.logger().info("Authentication is using Client Certificate Credentials");
return getClientCertificateCredential(tenantId, clientId, pemCertificate);
}

//Create default authentication
runContext.logger().info("Authentication is using Default Azure Credentials");
return new DefaultAzureCredentialBuilder().tenantId(tenantId).build();
}

private ClientCertificateCredential getClientCertificateCredential(String clientId, String tenantId, String pemCertificate) {
private ClientCertificateCredential getClientCertificateCredential(String tenantId, String clientId, String pemCertificate) {
return new ClientCertificateCredentialBuilder()
.clientId(clientId)
.tenantId(tenantId)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package io.kestra.plugin.azure.datafactory;
package io.kestra.plugin.azure;

import io.kestra.core.models.property.Property;
import io.swagger.v3.oas.annotations.media.Schema;

public interface AbstractDataFactoryConnectionInterface {
public interface AzureIdentityConnectionInterface {
@Schema(
title = "Client ID",
description = """
Expand Down Expand Up @@ -33,7 +33,4 @@ public interface AbstractDataFactoryConnectionInterface {

@Schema(title = "Tenant ID")
Property<String> getTenantId();

@Schema(title = "Subscription ID")
Property<String> getSubscriptionId();
}
95 changes: 95 additions & 0 deletions src/main/java/io/kestra/plugin/azure/auth/OauthAccessToken.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package io.kestra.plugin.azure.auth;

import com.azure.core.credential.*;
import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.property.Property;
import io.kestra.core.models.tasks.RunnableTask;
import io.kestra.core.models.tasks.common.EncryptedString;
import io.kestra.core.runners.RunContext;
import io.kestra.plugin.azure.AbstractAzureIdentityConnection;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.*;
import lombok.experimental.SuperBuilder;

import java.time.OffsetDateTime;
import java.util.Collections;
import java.util.List;

@SuperBuilder
@ToString
@EqualsAndHashCode
@Getter
@NoArgsConstructor
@Plugin(
examples = {
@Example(
full = true,
code = """
id: azure_get_token
namespace: company.team
tasks:
- id: get_access_token
type: io.kestra.plugin.azure.oauth.OauthAccessToken
tenantId: "{{ secret('SERVICE_PRINCIPAL_TENANT_ID') }}"
clientId: "{{ secret('SERVICE_PRINCIPAL_CLIENT_ID') }}"
clientSecret: "{{ secret('SERVICE_PRINCIPAL_CLIENT_SECRET') }}"
"""
)
}
)
@Schema(
title = "Fetch an OAuth access token."
)
public class OauthAccessToken extends AbstractAzureIdentityConnection implements RunnableTask<OauthAccessToken.Output> {
@Schema(title = "The Azure scopes to be used")
@Builder.Default
Property<List<String>> scopes = Property.of(Collections.singletonList("https://management.azure.com/.default"));

@Override
public Output run(RunContext runContext) throws Exception {
TokenCredential credential = this.credentials(runContext);

TokenRequestContext requestContext = new TokenRequestContext();
requestContext.setScopes(scopes.asList(runContext, String.class));

runContext.logger().info("Retrieve access token.");
AccessToken accessToken = credential.getTokenSync(requestContext);

runContext.logger().info("Successfully retrieved access token.");

var output = AccessTokenOutput.builder()
.expirationTime(accessToken.getExpiresAt())
.scopes(requestContext.getScopes())
.tokenValue(EncryptedString.from(accessToken.getToken(), runContext));

return Output
.builder()
.accessToken(output.build())
.build();
}

@Builder
@Getter
public static class Output implements io.kestra.core.models.tasks.Output {
@NotNull
@Schema(title = "An OAuth access token for the current user.")
private final AccessTokenOutput accessToken;
}

@Builder
@Getter
public static class AccessTokenOutput {
List<String> scopes;

@Schema(
title = "OAuth access token value",
description = "Will be automatically encrypted and decrypted in the outputs if encryption is configured"
)
EncryptedString tokenValue;

OffsetDateTime expirationTime;
}
}
8 changes: 8 additions & 0 deletions src/main/java/io/kestra/plugin/azure/auth/package-info.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
@PluginSubGroup(
title = "Authentication",
description = "This sub-group of plugins contains tasks to manage authentication for Azure services.",
categories = { PluginSubGroup.PluginCategory.CLOUD }
)
package io.kestra.plugin.azure.auth;

import io.kestra.core.models.annotations.PluginSubGroup;
37 changes: 31 additions & 6 deletions src/main/java/io/kestra/plugin/azure/datafactory/CreateRun.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.kestra.plugin.azure.datafactory;

import com.azure.core.management.AzureEnvironment;
import com.azure.core.management.profile.AzureProfile;
import com.azure.core.util.Context;
import com.azure.resourcemanager.datafactory.DataFactoryManager;
import com.azure.resourcemanager.datafactory.models.ActivityRun;
Expand All @@ -11,6 +13,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.executions.metrics.Counter;
Expand All @@ -21,8 +24,10 @@
import io.kestra.core.serializers.FileSerde;
import io.kestra.core.serializers.JacksonMapper;
import io.kestra.core.utils.Await;
import io.kestra.plugin.azure.AbstractAzureIdentityConnection;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.*;
import lombok.experimental.SuperBuilder;
import lombok.extern.slf4j.Slf4j;
Expand Down Expand Up @@ -76,12 +81,16 @@
description = "Launch an Azure DataFactory pipeline from Kestra. " +
"Data Factory contains a series of interconnected systems that provide a complete end-to-end platform for data engineers."
)
public class CreateRun extends AbstractDataFactoryConnection implements RunnableTask<CreateRun.Output> {
public class CreateRun extends AbstractAzureIdentityConnection implements RunnableTask<CreateRun.Output> {
private static final String PIPELINE_SUCCEEDED_STATUS = "Succeeded";
private static final List<String> PIPELINE_FAILED_STATUS = List.of("Failed", "Canceling", "Cancelled");
private static final Duration WAIT_UNTIL_COMPLETION = Duration.ofHours(1);
private static final Duration COMPLETION_CHECK_INTERVAL = Duration.ofSeconds(5);

@Schema(title = "Subscription ID")
@NotNull
protected Property<String> subscriptionId;

@Schema(title = "Factory name")
private Property<String> factoryName;

Expand Down Expand Up @@ -109,7 +118,7 @@ public CreateRun.Output run(RunContext runContext) throws Exception {
Logger logger = runContext.logger();

//Authentication
DataFactoryManager manager = this.getDataFactoryManager(runContext);
DataFactoryManager manager = this.dataFactoryManager(runContext);
logger.info("Successfully authenticate to Azure Data Factory");

//Create running pipeline
Expand Down Expand Up @@ -146,7 +155,7 @@ public CreateRun.Output run(RunContext runContext) throws Exception {
final AtomicReference<PipelineRun> runningPipelineResponse = new AtomicReference<>();
try {
Await.until(() -> {
runningPipelineResponse.set(getRunningPipeline(resourceGroupName,factoryName, runId, manager));
runningPipelineResponse.set(runningPipeline(resourceGroupName,factoryName, runId, manager));
String runStatus = runningPipelineResponse.get().status();

if (PIPELINE_FAILED_STATUS.contains(runStatus)) {
Expand Down Expand Up @@ -192,7 +201,7 @@ public CreateRun.Output run(RunContext runContext) throws Exception {
File tempFile = runContext.workingDir().createTempFile(".ion").toFile();
try (var output = new BufferedWriter(new FileWriter(tempFile))) {
var flux = Flux.fromIterable(activities);
Mono<Long> longMono = FileSerde.writeAll(getIonMapper(), output, flux);
Mono<Long> longMono = FileSerde.writeAll(ionMapper(), output, flux);
Long count = longMono.blockOptional().orElse(0L);

runContext.metric(Counter.of("activities", count));
Expand All @@ -216,11 +225,27 @@ public static class Output implements io.kestra.core.models.tasks.Output {
private URI uri;
}

private PipelineRun getRunningPipeline(String resourceGroupName, String factoryName, String runId, DataFactoryManager manager) {
private DataFactoryManager dataFactoryManager(RunContext runContext) throws IllegalVariableEvaluationException {
runContext.logger().info("Authenticating to Azure Data Factory");
return DataFactoryManager.authenticate(credentials(runContext), profile(runContext));
}

public AzureProfile profile(RunContext runContext) throws IllegalVariableEvaluationException {
final String tenantId = this.tenantId.as(runContext, String.class);
final String subscriptionId = this.subscriptionId.as(runContext, String.class);

return new AzureProfile(
tenantId,
subscriptionId,
AzureEnvironment.AZURE
);
}

private PipelineRun runningPipeline(String resourceGroupName, String factoryName, String runId, DataFactoryManager manager) {
return manager.pipelineRuns().get(resourceGroupName, factoryName, runId);
}

private static ObjectMapper getIonMapper() {
private static ObjectMapper ionMapper() {
ObjectMapper ionMapper = new ObjectMapper(JacksonMapper.ofIon().getFactory());
ionMapper.setSerializationInclusion(JsonInclude.Include.ALWAYS);
ionMapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
Expand Down
23 changes: 23 additions & 0 deletions src/main/resources/icons/io.kestra.plugin.azure.auth.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 5b8a385

Please sign in to comment.