diff --git a/src/main/java/io/kestra/plugin/azure/datafactory/AbstractDataFactoryConnection.java b/src/main/java/io/kestra/plugin/azure/AbstractAzureIdentityConnection.java similarity index 56% rename from src/main/java/io/kestra/plugin/azure/datafactory/AbstractDataFactoryConnection.java rename to src/main/java/io/kestra/plugin/azure/AbstractAzureIdentityConnection.java index 514ba27..430187c 100644 --- a/src/main/java/io/kestra/plugin/azure/datafactory/AbstractDataFactoryConnection.java +++ b/src/main/java/io/kestra/plugin/azure/AbstractAzureIdentityConnection.java @@ -1,10 +1,7 @@ -package io.kestra.plugin.azure.datafactory; +package io.kestra.plugin.azure; import com.azure.core.credential.TokenCredential; -import com.azure.core.management.AzureEnvironment; -import com.azure.core.management.profile.AzureProfile; import com.azure.identity.*; -import com.azure.resourcemanager.datafactory.DataFactoryManager; import io.kestra.core.exceptions.IllegalVariableEvaluationException; import io.kestra.core.models.property.Property; import io.kestra.core.models.tasks.Task; @@ -17,16 +14,20 @@ import java.io.ByteArrayInputStream; import java.nio.charset.StandardCharsets; +/** + * This class enables the creation of Azure credentials from different sources. + * For more information please refer to the Azure Identity documentation + * + */ @SuperBuilder @ToString @EqualsAndHashCode @Getter @NoArgsConstructor -public abstract class AbstractDataFactoryConnection extends Task implements AbstractDataFactoryConnectionInterface { +public abstract class AbstractAzureIdentityConnection extends Task implements AzureIdentityConnectionInterface { @NotNull - protected Property subscriptionId; - @NotNull - protected Property tenantId; + @Builder.Default + protected Property tenantId = Property.of(""); @Builder.Default protected Property clientId = Property.of(""); @@ -35,43 +36,30 @@ public abstract class AbstractDataFactoryConnection extends Task implements Abst @Builder.Default protected Property pemCertificate = Property.of(""); - protected DataFactoryManager getDataFactoryManager(RunContext runContext) throws IllegalVariableEvaluationException { - runContext.logger().info("Authenticating to Azure Data Factory"); - return DataFactoryManager.authenticate(credentials(runContext), profile(runContext)); - } - - private AzureProfile profile(RunContext runContext) throws IllegalVariableEvaluationException { + public TokenCredential credentials(RunContext runContext) throws IllegalVariableEvaluationException { final String tenantId = this.tenantId.as(runContext, String.class); - final String subscriptionId = this.subscriptionId.as(runContext, String.class); - - return new AzureProfile( - tenantId, - subscriptionId, - AzureEnvironment.AZURE - ); - } - - private TokenCredential credentials(RunContext runContext) throws IllegalVariableEvaluationException { - final String tenantId = runContext.render(this.tenantId.as(runContext, String.class)); - final String clientId = runContext.render(this.clientId.as(runContext, String.class)); + final String clientId = this.clientId.as(runContext, String.class); //Create client/secret credentials - final String clientSecret = runContext.render(this.clientSecret.as(runContext, String.class)); + final String clientSecret = this.clientSecret.as(runContext, String.class); if(StringUtils.isNotBlank(clientSecret)) { + runContext.logger().info("Authentication is using Client Secret Credentials"); return getClientSecretCredential(tenantId, clientId, clientSecret); } //Create client/certificate credentials - final String pemCertificate = runContext.render(this.pemCertificate.as(runContext, String.class)); + final String pemCertificate = this.pemCertificate.as(runContext, String.class); if(StringUtils.isNotBlank(pemCertificate)) { + runContext.logger().info("Authentication is using Client Certificate Credentials"); return getClientCertificateCredential(tenantId, clientId, pemCertificate); } //Create default authentication + runContext.logger().info("Authentication is using Default Azure Credentials"); return new DefaultAzureCredentialBuilder().tenantId(tenantId).build(); } - private ClientCertificateCredential getClientCertificateCredential(String clientId, String tenantId, String pemCertificate) { + private ClientCertificateCredential getClientCertificateCredential(String tenantId, String clientId, String pemCertificate) { return new ClientCertificateCredentialBuilder() .clientId(clientId) .tenantId(tenantId) diff --git a/src/main/java/io/kestra/plugin/azure/datafactory/AbstractDataFactoryConnectionInterface.java b/src/main/java/io/kestra/plugin/azure/AzureIdentityConnectionInterface.java similarity index 87% rename from src/main/java/io/kestra/plugin/azure/datafactory/AbstractDataFactoryConnectionInterface.java rename to src/main/java/io/kestra/plugin/azure/AzureIdentityConnectionInterface.java index 6511eb8..e902e39 100644 --- a/src/main/java/io/kestra/plugin/azure/datafactory/AbstractDataFactoryConnectionInterface.java +++ b/src/main/java/io/kestra/plugin/azure/AzureIdentityConnectionInterface.java @@ -1,9 +1,9 @@ -package io.kestra.plugin.azure.datafactory; +package io.kestra.plugin.azure; import io.kestra.core.models.property.Property; import io.swagger.v3.oas.annotations.media.Schema; -public interface AbstractDataFactoryConnectionInterface { +public interface AzureIdentityConnectionInterface { @Schema( title = "Client ID", description = """ @@ -33,7 +33,4 @@ public interface AbstractDataFactoryConnectionInterface { @Schema(title = "Tenant ID") Property getTenantId(); - - @Schema(title = "Subscription ID") - Property getSubscriptionId(); } diff --git a/src/main/java/io/kestra/plugin/azure/auth/OauthAccessToken.java b/src/main/java/io/kestra/plugin/azure/auth/OauthAccessToken.java new file mode 100644 index 0000000..96ffae6 --- /dev/null +++ b/src/main/java/io/kestra/plugin/azure/auth/OauthAccessToken.java @@ -0,0 +1,95 @@ +package io.kestra.plugin.azure.auth; + +import com.azure.core.credential.*; +import io.kestra.core.models.annotations.Example; +import io.kestra.core.models.annotations.Plugin; +import io.kestra.core.models.property.Property; +import io.kestra.core.models.tasks.RunnableTask; +import io.kestra.core.models.tasks.common.EncryptedString; +import io.kestra.core.runners.RunContext; +import io.kestra.plugin.azure.AbstractAzureIdentityConnection; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotNull; +import lombok.*; +import lombok.experimental.SuperBuilder; + +import java.time.OffsetDateTime; +import java.util.Collections; +import java.util.List; + +@SuperBuilder +@ToString +@EqualsAndHashCode +@Getter +@NoArgsConstructor +@Plugin( + examples = { + @Example( + full = true, + code = """ + id: azure_get_token + namespace: company.team + + tasks: + - id: get_access_token + type: io.kestra.plugin.azure.oauth.OauthAccessToken + tenantId: "{{ secret('SERVICE_PRINCIPAL_TENANT_ID') }}" + clientId: "{{ secret('SERVICE_PRINCIPAL_CLIENT_ID') }}" + clientSecret: "{{ secret('SERVICE_PRINCIPAL_CLIENT_SECRET') }}" + """ + ) + } +) +@Schema( + title = "Fetch an OAuth access token." +) +public class OauthAccessToken extends AbstractAzureIdentityConnection implements RunnableTask { + @Schema(title = "The Azure scopes to be used") + @Builder.Default + Property> scopes = Property.of(Collections.singletonList("https://management.azure.com/.default")); + + @Override + public Output run(RunContext runContext) throws Exception { + TokenCredential credential = this.credentials(runContext); + + TokenRequestContext requestContext = new TokenRequestContext(); + requestContext.setScopes(scopes.asList(runContext, String.class)); + + runContext.logger().info("Retrieve access token."); + AccessToken accessToken = credential.getTokenSync(requestContext); + + runContext.logger().info("Successfully retrieved access token."); + + var output = AccessTokenOutput.builder() + .expirationTime(accessToken.getExpiresAt()) + .scopes(requestContext.getScopes()) + .tokenValue(EncryptedString.from(accessToken.getToken(), runContext)); + + return Output + .builder() + .accessToken(output.build()) + .build(); + } + + @Builder + @Getter + public static class Output implements io.kestra.core.models.tasks.Output { + @NotNull + @Schema(title = "An OAuth access token for the current user.") + private final AccessTokenOutput accessToken; + } + + @Builder + @Getter + public static class AccessTokenOutput { + List scopes; + + @Schema( + title = "OAuth access token value", + description = "Will be automatically encrypted and decrypted in the outputs if encryption is configured" + ) + EncryptedString tokenValue; + + OffsetDateTime expirationTime; + } +} diff --git a/src/main/java/io/kestra/plugin/azure/auth/package-info.java b/src/main/java/io/kestra/plugin/azure/auth/package-info.java new file mode 100644 index 0000000..49fc654 --- /dev/null +++ b/src/main/java/io/kestra/plugin/azure/auth/package-info.java @@ -0,0 +1,8 @@ +@PluginSubGroup( + title = "Authentication", + description = "This sub-group of plugins contains tasks to manage authentication for Azure services.", + categories = { PluginSubGroup.PluginCategory.CLOUD } +) +package io.kestra.plugin.azure.auth; + +import io.kestra.core.models.annotations.PluginSubGroup; \ No newline at end of file diff --git a/src/main/java/io/kestra/plugin/azure/datafactory/CreateRun.java b/src/main/java/io/kestra/plugin/azure/datafactory/CreateRun.java index 2b3405e..8b94cae 100644 --- a/src/main/java/io/kestra/plugin/azure/datafactory/CreateRun.java +++ b/src/main/java/io/kestra/plugin/azure/datafactory/CreateRun.java @@ -1,5 +1,7 @@ package io.kestra.plugin.azure.datafactory; +import com.azure.core.management.AzureEnvironment; +import com.azure.core.management.profile.AzureProfile; import com.azure.core.util.Context; import com.azure.resourcemanager.datafactory.DataFactoryManager; import com.azure.resourcemanager.datafactory.models.ActivityRun; @@ -11,6 +13,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.SerializationFeature; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import io.kestra.core.exceptions.IllegalVariableEvaluationException; import io.kestra.core.models.annotations.Example; import io.kestra.core.models.annotations.Plugin; import io.kestra.core.models.executions.metrics.Counter; @@ -21,8 +24,10 @@ import io.kestra.core.serializers.FileSerde; import io.kestra.core.serializers.JacksonMapper; import io.kestra.core.utils.Await; +import io.kestra.plugin.azure.AbstractAzureIdentityConnection; import io.netty.handler.codec.http.HttpResponseStatus; import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotNull; import lombok.*; import lombok.experimental.SuperBuilder; import lombok.extern.slf4j.Slf4j; @@ -76,12 +81,16 @@ description = "Launch an Azure DataFactory pipeline from Kestra. " + "Data Factory contains a series of interconnected systems that provide a complete end-to-end platform for data engineers." ) -public class CreateRun extends AbstractDataFactoryConnection implements RunnableTask { +public class CreateRun extends AbstractAzureIdentityConnection implements RunnableTask { private static final String PIPELINE_SUCCEEDED_STATUS = "Succeeded"; private static final List PIPELINE_FAILED_STATUS = List.of("Failed", "Canceling", "Cancelled"); private static final Duration WAIT_UNTIL_COMPLETION = Duration.ofHours(1); private static final Duration COMPLETION_CHECK_INTERVAL = Duration.ofSeconds(5); + @Schema(title = "Subscription ID") + @NotNull + protected Property subscriptionId; + @Schema(title = "Factory name") private Property factoryName; @@ -109,7 +118,7 @@ public CreateRun.Output run(RunContext runContext) throws Exception { Logger logger = runContext.logger(); //Authentication - DataFactoryManager manager = this.getDataFactoryManager(runContext); + DataFactoryManager manager = this.dataFactoryManager(runContext); logger.info("Successfully authenticate to Azure Data Factory"); //Create running pipeline @@ -146,7 +155,7 @@ public CreateRun.Output run(RunContext runContext) throws Exception { final AtomicReference runningPipelineResponse = new AtomicReference<>(); try { Await.until(() -> { - runningPipelineResponse.set(getRunningPipeline(resourceGroupName,factoryName, runId, manager)); + runningPipelineResponse.set(runningPipeline(resourceGroupName,factoryName, runId, manager)); String runStatus = runningPipelineResponse.get().status(); if (PIPELINE_FAILED_STATUS.contains(runStatus)) { @@ -192,7 +201,7 @@ public CreateRun.Output run(RunContext runContext) throws Exception { File tempFile = runContext.workingDir().createTempFile(".ion").toFile(); try (var output = new BufferedWriter(new FileWriter(tempFile))) { var flux = Flux.fromIterable(activities); - Mono longMono = FileSerde.writeAll(getIonMapper(), output, flux); + Mono longMono = FileSerde.writeAll(ionMapper(), output, flux); Long count = longMono.blockOptional().orElse(0L); runContext.metric(Counter.of("activities", count)); @@ -216,11 +225,27 @@ public static class Output implements io.kestra.core.models.tasks.Output { private URI uri; } - private PipelineRun getRunningPipeline(String resourceGroupName, String factoryName, String runId, DataFactoryManager manager) { + private DataFactoryManager dataFactoryManager(RunContext runContext) throws IllegalVariableEvaluationException { + runContext.logger().info("Authenticating to Azure Data Factory"); + return DataFactoryManager.authenticate(credentials(runContext), profile(runContext)); + } + + public AzureProfile profile(RunContext runContext) throws IllegalVariableEvaluationException { + final String tenantId = this.tenantId.as(runContext, String.class); + final String subscriptionId = this.subscriptionId.as(runContext, String.class); + + return new AzureProfile( + tenantId, + subscriptionId, + AzureEnvironment.AZURE + ); + } + + private PipelineRun runningPipeline(String resourceGroupName, String factoryName, String runId, DataFactoryManager manager) { return manager.pipelineRuns().get(resourceGroupName, factoryName, runId); } - private static ObjectMapper getIonMapper() { + private static ObjectMapper ionMapper() { ObjectMapper ionMapper = new ObjectMapper(JacksonMapper.ofIon().getFactory()); ionMapper.setSerializationInclusion(JsonInclude.Include.ALWAYS); ionMapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); diff --git a/src/main/resources/icons/io.kestra.plugin.azure.auth.svg b/src/main/resources/icons/io.kestra.plugin.azure.auth.svg new file mode 100644 index 0000000..38ed9e1 --- /dev/null +++ b/src/main/resources/icons/io.kestra.plugin.azure.auth.svg @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + + + + diff --git a/src/test/java/io/kestra/plugin/azure/auth/OauthAccessTokenTest.java b/src/test/java/io/kestra/plugin/azure/auth/OauthAccessTokenTest.java new file mode 100644 index 0000000..9c3f4df --- /dev/null +++ b/src/test/java/io/kestra/plugin/azure/auth/OauthAccessTokenTest.java @@ -0,0 +1,50 @@ +package io.kestra.plugin.azure.auth; + +import io.kestra.core.junit.annotations.KestraTest; +import io.kestra.core.models.property.Property; +import io.kestra.core.runners.RunContextFactory; +import jakarta.inject.Inject; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.util.Collections; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.notNullValue; + +@KestraTest +class OauthAccessTokenTest { + @Inject + private RunContextFactory runContextFactory; + + @Disabled("To run this test make sure your are logged in via the command line with 'az login'") + @Test + void getAccessTokenFromDefaultCredential() throws Exception { + OauthAccessToken task = OauthAccessToken.builder() + .build(); + + OauthAccessToken.Output run = task.run(runContextFactory.of(Collections.emptyMap())); + + OauthAccessToken.AccessTokenOutput accessToken = run.getAccessToken(); + assertThat(accessToken, notNullValue()); + } + + @Disabled("To run this test provide your service principal credentials") + @Test + void getAccessTokenFromClientSecretCredential() throws Exception { + final String tenantId = ""; + final String clientId = ""; + final String clientSecret = ""; + + OauthAccessToken task = OauthAccessToken.builder() + .tenantId(Property.of(tenantId)) + .clientId(Property.of(clientId)) + .clientSecret(Property.of(clientSecret)) + .build(); + + OauthAccessToken.Output run = task.run(runContextFactory.of(Collections.emptyMap())); + + OauthAccessToken.AccessTokenOutput accessToken = run.getAccessToken(); + assertThat(accessToken, notNullValue()); + } +} \ No newline at end of file diff --git a/src/test/java/io/kestra/plugin/azure/datafactory/CreateRunTest.java b/src/test/java/io/kestra/plugin/azure/datafactory/CreateRunTest.java index 5784b11..3a1a7a3 100644 --- a/src/test/java/io/kestra/plugin/azure/datafactory/CreateRunTest.java +++ b/src/test/java/io/kestra/plugin/azure/datafactory/CreateRunTest.java @@ -13,7 +13,6 @@ import java.io.BufferedReader; import java.io.InputStreamReader; -import java.time.Duration; import java.util.ArrayList; import java.util.List; import java.util.Map;