diff --git a/src/main/java/io/kestra/plugin/azure/datafactory/AbstractDataFactoryConnection.java b/src/main/java/io/kestra/plugin/azure/AbstractAzureIdentityConnection.java
similarity index 56%
rename from src/main/java/io/kestra/plugin/azure/datafactory/AbstractDataFactoryConnection.java
rename to src/main/java/io/kestra/plugin/azure/AbstractAzureIdentityConnection.java
index 514ba27..430187c 100644
--- a/src/main/java/io/kestra/plugin/azure/datafactory/AbstractDataFactoryConnection.java
+++ b/src/main/java/io/kestra/plugin/azure/AbstractAzureIdentityConnection.java
@@ -1,10 +1,7 @@
-package io.kestra.plugin.azure.datafactory;
+package io.kestra.plugin.azure;
import com.azure.core.credential.TokenCredential;
-import com.azure.core.management.AzureEnvironment;
-import com.azure.core.management.profile.AzureProfile;
import com.azure.identity.*;
-import com.azure.resourcemanager.datafactory.DataFactoryManager;
import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.property.Property;
import io.kestra.core.models.tasks.Task;
@@ -17,16 +14,20 @@
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
+/**
+ * This class enables the creation of Azure credentials from different sources.
+ * For more information please refer to the Azure Identity documentation
+ *
+ */
@SuperBuilder
@ToString
@EqualsAndHashCode
@Getter
@NoArgsConstructor
-public abstract class AbstractDataFactoryConnection extends Task implements AbstractDataFactoryConnectionInterface {
+public abstract class AbstractAzureIdentityConnection extends Task implements AzureIdentityConnectionInterface {
@NotNull
- protected Property subscriptionId;
- @NotNull
- protected Property tenantId;
+ @Builder.Default
+ protected Property tenantId = Property.of("");
@Builder.Default
protected Property clientId = Property.of("");
@@ -35,43 +36,30 @@ public abstract class AbstractDataFactoryConnection extends Task implements Abst
@Builder.Default
protected Property pemCertificate = Property.of("");
- protected DataFactoryManager getDataFactoryManager(RunContext runContext) throws IllegalVariableEvaluationException {
- runContext.logger().info("Authenticating to Azure Data Factory");
- return DataFactoryManager.authenticate(credentials(runContext), profile(runContext));
- }
-
- private AzureProfile profile(RunContext runContext) throws IllegalVariableEvaluationException {
+ public TokenCredential credentials(RunContext runContext) throws IllegalVariableEvaluationException {
final String tenantId = this.tenantId.as(runContext, String.class);
- final String subscriptionId = this.subscriptionId.as(runContext, String.class);
-
- return new AzureProfile(
- tenantId,
- subscriptionId,
- AzureEnvironment.AZURE
- );
- }
-
- private TokenCredential credentials(RunContext runContext) throws IllegalVariableEvaluationException {
- final String tenantId = runContext.render(this.tenantId.as(runContext, String.class));
- final String clientId = runContext.render(this.clientId.as(runContext, String.class));
+ final String clientId = this.clientId.as(runContext, String.class);
//Create client/secret credentials
- final String clientSecret = runContext.render(this.clientSecret.as(runContext, String.class));
+ final String clientSecret = this.clientSecret.as(runContext, String.class);
if(StringUtils.isNotBlank(clientSecret)) {
+ runContext.logger().info("Authentication is using Client Secret Credentials");
return getClientSecretCredential(tenantId, clientId, clientSecret);
}
//Create client/certificate credentials
- final String pemCertificate = runContext.render(this.pemCertificate.as(runContext, String.class));
+ final String pemCertificate = this.pemCertificate.as(runContext, String.class);
if(StringUtils.isNotBlank(pemCertificate)) {
+ runContext.logger().info("Authentication is using Client Certificate Credentials");
return getClientCertificateCredential(tenantId, clientId, pemCertificate);
}
//Create default authentication
+ runContext.logger().info("Authentication is using Default Azure Credentials");
return new DefaultAzureCredentialBuilder().tenantId(tenantId).build();
}
- private ClientCertificateCredential getClientCertificateCredential(String clientId, String tenantId, String pemCertificate) {
+ private ClientCertificateCredential getClientCertificateCredential(String tenantId, String clientId, String pemCertificate) {
return new ClientCertificateCredentialBuilder()
.clientId(clientId)
.tenantId(tenantId)
diff --git a/src/main/java/io/kestra/plugin/azure/datafactory/AbstractDataFactoryConnectionInterface.java b/src/main/java/io/kestra/plugin/azure/AzureIdentityConnectionInterface.java
similarity index 87%
rename from src/main/java/io/kestra/plugin/azure/datafactory/AbstractDataFactoryConnectionInterface.java
rename to src/main/java/io/kestra/plugin/azure/AzureIdentityConnectionInterface.java
index 6511eb8..e902e39 100644
--- a/src/main/java/io/kestra/plugin/azure/datafactory/AbstractDataFactoryConnectionInterface.java
+++ b/src/main/java/io/kestra/plugin/azure/AzureIdentityConnectionInterface.java
@@ -1,9 +1,9 @@
-package io.kestra.plugin.azure.datafactory;
+package io.kestra.plugin.azure;
import io.kestra.core.models.property.Property;
import io.swagger.v3.oas.annotations.media.Schema;
-public interface AbstractDataFactoryConnectionInterface {
+public interface AzureIdentityConnectionInterface {
@Schema(
title = "Client ID",
description = """
@@ -33,7 +33,4 @@ public interface AbstractDataFactoryConnectionInterface {
@Schema(title = "Tenant ID")
Property getTenantId();
-
- @Schema(title = "Subscription ID")
- Property getSubscriptionId();
}
diff --git a/src/main/java/io/kestra/plugin/azure/auth/OauthAccessToken.java b/src/main/java/io/kestra/plugin/azure/auth/OauthAccessToken.java
new file mode 100644
index 0000000..96ffae6
--- /dev/null
+++ b/src/main/java/io/kestra/plugin/azure/auth/OauthAccessToken.java
@@ -0,0 +1,95 @@
+package io.kestra.plugin.azure.auth;
+
+import com.azure.core.credential.*;
+import io.kestra.core.models.annotations.Example;
+import io.kestra.core.models.annotations.Plugin;
+import io.kestra.core.models.property.Property;
+import io.kestra.core.models.tasks.RunnableTask;
+import io.kestra.core.models.tasks.common.EncryptedString;
+import io.kestra.core.runners.RunContext;
+import io.kestra.plugin.azure.AbstractAzureIdentityConnection;
+import io.swagger.v3.oas.annotations.media.Schema;
+import jakarta.validation.constraints.NotNull;
+import lombok.*;
+import lombok.experimental.SuperBuilder;
+
+import java.time.OffsetDateTime;
+import java.util.Collections;
+import java.util.List;
+
+@SuperBuilder
+@ToString
+@EqualsAndHashCode
+@Getter
+@NoArgsConstructor
+@Plugin(
+ examples = {
+ @Example(
+ full = true,
+ code = """
+ id: azure_get_token
+ namespace: company.team
+
+ tasks:
+ - id: get_access_token
+ type: io.kestra.plugin.azure.oauth.OauthAccessToken
+ tenantId: "{{ secret('SERVICE_PRINCIPAL_TENANT_ID') }}"
+ clientId: "{{ secret('SERVICE_PRINCIPAL_CLIENT_ID') }}"
+ clientSecret: "{{ secret('SERVICE_PRINCIPAL_CLIENT_SECRET') }}"
+ """
+ )
+ }
+)
+@Schema(
+ title = "Fetch an OAuth access token."
+)
+public class OauthAccessToken extends AbstractAzureIdentityConnection implements RunnableTask {
+ @Schema(title = "The Azure scopes to be used")
+ @Builder.Default
+ Property> scopes = Property.of(Collections.singletonList("https://management.azure.com/.default"));
+
+ @Override
+ public Output run(RunContext runContext) throws Exception {
+ TokenCredential credential = this.credentials(runContext);
+
+ TokenRequestContext requestContext = new TokenRequestContext();
+ requestContext.setScopes(scopes.asList(runContext, String.class));
+
+ runContext.logger().info("Retrieve access token.");
+ AccessToken accessToken = credential.getTokenSync(requestContext);
+
+ runContext.logger().info("Successfully retrieved access token.");
+
+ var output = AccessTokenOutput.builder()
+ .expirationTime(accessToken.getExpiresAt())
+ .scopes(requestContext.getScopes())
+ .tokenValue(EncryptedString.from(accessToken.getToken(), runContext));
+
+ return Output
+ .builder()
+ .accessToken(output.build())
+ .build();
+ }
+
+ @Builder
+ @Getter
+ public static class Output implements io.kestra.core.models.tasks.Output {
+ @NotNull
+ @Schema(title = "An OAuth access token for the current user.")
+ private final AccessTokenOutput accessToken;
+ }
+
+ @Builder
+ @Getter
+ public static class AccessTokenOutput {
+ List scopes;
+
+ @Schema(
+ title = "OAuth access token value",
+ description = "Will be automatically encrypted and decrypted in the outputs if encryption is configured"
+ )
+ EncryptedString tokenValue;
+
+ OffsetDateTime expirationTime;
+ }
+}
diff --git a/src/main/java/io/kestra/plugin/azure/auth/package-info.java b/src/main/java/io/kestra/plugin/azure/auth/package-info.java
new file mode 100644
index 0000000..49fc654
--- /dev/null
+++ b/src/main/java/io/kestra/plugin/azure/auth/package-info.java
@@ -0,0 +1,8 @@
+@PluginSubGroup(
+ title = "Authentication",
+ description = "This sub-group of plugins contains tasks to manage authentication for Azure services.",
+ categories = { PluginSubGroup.PluginCategory.CLOUD }
+)
+package io.kestra.plugin.azure.auth;
+
+import io.kestra.core.models.annotations.PluginSubGroup;
\ No newline at end of file
diff --git a/src/main/java/io/kestra/plugin/azure/datafactory/CreateRun.java b/src/main/java/io/kestra/plugin/azure/datafactory/CreateRun.java
index 2b3405e..8b94cae 100644
--- a/src/main/java/io/kestra/plugin/azure/datafactory/CreateRun.java
+++ b/src/main/java/io/kestra/plugin/azure/datafactory/CreateRun.java
@@ -1,5 +1,7 @@
package io.kestra.plugin.azure.datafactory;
+import com.azure.core.management.AzureEnvironment;
+import com.azure.core.management.profile.AzureProfile;
import com.azure.core.util.Context;
import com.azure.resourcemanager.datafactory.DataFactoryManager;
import com.azure.resourcemanager.datafactory.models.ActivityRun;
@@ -11,6 +13,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
+import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.executions.metrics.Counter;
@@ -21,8 +24,10 @@
import io.kestra.core.serializers.FileSerde;
import io.kestra.core.serializers.JacksonMapper;
import io.kestra.core.utils.Await;
+import io.kestra.plugin.azure.AbstractAzureIdentityConnection;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.swagger.v3.oas.annotations.media.Schema;
+import jakarta.validation.constraints.NotNull;
import lombok.*;
import lombok.experimental.SuperBuilder;
import lombok.extern.slf4j.Slf4j;
@@ -76,12 +81,16 @@
description = "Launch an Azure DataFactory pipeline from Kestra. " +
"Data Factory contains a series of interconnected systems that provide a complete end-to-end platform for data engineers."
)
-public class CreateRun extends AbstractDataFactoryConnection implements RunnableTask {
+public class CreateRun extends AbstractAzureIdentityConnection implements RunnableTask {
private static final String PIPELINE_SUCCEEDED_STATUS = "Succeeded";
private static final List PIPELINE_FAILED_STATUS = List.of("Failed", "Canceling", "Cancelled");
private static final Duration WAIT_UNTIL_COMPLETION = Duration.ofHours(1);
private static final Duration COMPLETION_CHECK_INTERVAL = Duration.ofSeconds(5);
+ @Schema(title = "Subscription ID")
+ @NotNull
+ protected Property subscriptionId;
+
@Schema(title = "Factory name")
private Property factoryName;
@@ -109,7 +118,7 @@ public CreateRun.Output run(RunContext runContext) throws Exception {
Logger logger = runContext.logger();
//Authentication
- DataFactoryManager manager = this.getDataFactoryManager(runContext);
+ DataFactoryManager manager = this.dataFactoryManager(runContext);
logger.info("Successfully authenticate to Azure Data Factory");
//Create running pipeline
@@ -146,7 +155,7 @@ public CreateRun.Output run(RunContext runContext) throws Exception {
final AtomicReference runningPipelineResponse = new AtomicReference<>();
try {
Await.until(() -> {
- runningPipelineResponse.set(getRunningPipeline(resourceGroupName,factoryName, runId, manager));
+ runningPipelineResponse.set(runningPipeline(resourceGroupName,factoryName, runId, manager));
String runStatus = runningPipelineResponse.get().status();
if (PIPELINE_FAILED_STATUS.contains(runStatus)) {
@@ -192,7 +201,7 @@ public CreateRun.Output run(RunContext runContext) throws Exception {
File tempFile = runContext.workingDir().createTempFile(".ion").toFile();
try (var output = new BufferedWriter(new FileWriter(tempFile))) {
var flux = Flux.fromIterable(activities);
- Mono longMono = FileSerde.writeAll(getIonMapper(), output, flux);
+ Mono longMono = FileSerde.writeAll(ionMapper(), output, flux);
Long count = longMono.blockOptional().orElse(0L);
runContext.metric(Counter.of("activities", count));
@@ -216,11 +225,27 @@ public static class Output implements io.kestra.core.models.tasks.Output {
private URI uri;
}
- private PipelineRun getRunningPipeline(String resourceGroupName, String factoryName, String runId, DataFactoryManager manager) {
+ private DataFactoryManager dataFactoryManager(RunContext runContext) throws IllegalVariableEvaluationException {
+ runContext.logger().info("Authenticating to Azure Data Factory");
+ return DataFactoryManager.authenticate(credentials(runContext), profile(runContext));
+ }
+
+ public AzureProfile profile(RunContext runContext) throws IllegalVariableEvaluationException {
+ final String tenantId = this.tenantId.as(runContext, String.class);
+ final String subscriptionId = this.subscriptionId.as(runContext, String.class);
+
+ return new AzureProfile(
+ tenantId,
+ subscriptionId,
+ AzureEnvironment.AZURE
+ );
+ }
+
+ private PipelineRun runningPipeline(String resourceGroupName, String factoryName, String runId, DataFactoryManager manager) {
return manager.pipelineRuns().get(resourceGroupName, factoryName, runId);
}
- private static ObjectMapper getIonMapper() {
+ private static ObjectMapper ionMapper() {
ObjectMapper ionMapper = new ObjectMapper(JacksonMapper.ofIon().getFactory());
ionMapper.setSerializationInclusion(JsonInclude.Include.ALWAYS);
ionMapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
diff --git a/src/main/resources/icons/io.kestra.plugin.azure.auth.svg b/src/main/resources/icons/io.kestra.plugin.azure.auth.svg
new file mode 100644
index 0000000..38ed9e1
--- /dev/null
+++ b/src/main/resources/icons/io.kestra.plugin.azure.auth.svg
@@ -0,0 +1,23 @@
+
diff --git a/src/test/java/io/kestra/plugin/azure/auth/OauthAccessTokenTest.java b/src/test/java/io/kestra/plugin/azure/auth/OauthAccessTokenTest.java
new file mode 100644
index 0000000..9c3f4df
--- /dev/null
+++ b/src/test/java/io/kestra/plugin/azure/auth/OauthAccessTokenTest.java
@@ -0,0 +1,50 @@
+package io.kestra.plugin.azure.auth;
+
+import io.kestra.core.junit.annotations.KestraTest;
+import io.kestra.core.models.property.Property;
+import io.kestra.core.runners.RunContextFactory;
+import jakarta.inject.Inject;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.notNullValue;
+
+@KestraTest
+class OauthAccessTokenTest {
+ @Inject
+ private RunContextFactory runContextFactory;
+
+ @Disabled("To run this test make sure your are logged in via the command line with 'az login'")
+ @Test
+ void getAccessTokenFromDefaultCredential() throws Exception {
+ OauthAccessToken task = OauthAccessToken.builder()
+ .build();
+
+ OauthAccessToken.Output run = task.run(runContextFactory.of(Collections.emptyMap()));
+
+ OauthAccessToken.AccessTokenOutput accessToken = run.getAccessToken();
+ assertThat(accessToken, notNullValue());
+ }
+
+ @Disabled("To run this test provide your service principal credentials")
+ @Test
+ void getAccessTokenFromClientSecretCredential() throws Exception {
+ final String tenantId = "";
+ final String clientId = "";
+ final String clientSecret = "";
+
+ OauthAccessToken task = OauthAccessToken.builder()
+ .tenantId(Property.of(tenantId))
+ .clientId(Property.of(clientId))
+ .clientSecret(Property.of(clientSecret))
+ .build();
+
+ OauthAccessToken.Output run = task.run(runContextFactory.of(Collections.emptyMap()));
+
+ OauthAccessToken.AccessTokenOutput accessToken = run.getAccessToken();
+ assertThat(accessToken, notNullValue());
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/io/kestra/plugin/azure/datafactory/CreateRunTest.java b/src/test/java/io/kestra/plugin/azure/datafactory/CreateRunTest.java
index 5784b11..3a1a7a3 100644
--- a/src/test/java/io/kestra/plugin/azure/datafactory/CreateRunTest.java
+++ b/src/test/java/io/kestra/plugin/azure/datafactory/CreateRunTest.java
@@ -13,7 +13,6 @@
import java.io.BufferedReader;
import java.io.InputStreamReader;
-import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;