Add Oracle to airflow->cosmos profile map #1190

Open · wants to merge 14 commits into base: main
6 changes: 6 additions & 0 deletions README.rst
@@ -77,3 +77,9 @@ License
_______

`Apache License 2.0 <https://github.com/astronomer/astronomer-cosmos/blob/main/LICENSE>`_


Privacy Notice
______________

This project follows `Astronomer's Privacy Policy <https://www.astronomer.io/privacy/>`_
42 changes: 42 additions & 0 deletions cosmos/__init__.py
@@ -166,6 +166,40 @@
DbtTestAwsEksOperator = MissingPackage("cosmos.operators.aws_eks.DbtTestAwsEksOperator", "aws_eks")


try:
from cosmos.operators.gcp_cloud_run_job import (
DbtBuildGcpCloudRunJobOperator,
DbtLSGcpCloudRunJobOperator,
DbtRunGcpCloudRunJobOperator,
DbtRunOperationGcpCloudRunJobOperator,
DbtSeedGcpCloudRunJobOperator,
DbtSnapshotGcpCloudRunJobOperator,
DbtTestGcpCloudRunJobOperator,
)
except (ImportError, AttributeError):
DbtBuildGcpCloudRunJobOperator = MissingPackage(
"cosmos.operators.gcp_cloud_run_job.DbtBuildGcpCloudRunJobOperator", "gcp-cloud-run-job"
)
DbtLSGcpCloudRunJobOperator = MissingPackage(
"cosmos.operators.gcp_cloud_run_job.DbtLSGcpCloudRunJobOperator", "gcp-cloud-run-job"
)
DbtRunGcpCloudRunJobOperator = MissingPackage(
"cosmos.operators.gcp_cloud_run_job.DbtRunGcpCloudRunJobOperator", "gcp-cloud-run-job"
)
DbtRunOperationGcpCloudRunJobOperator = MissingPackage(
"cosmos.operators.gcp_cloud_run_job.DbtRunOperationGcpCloudRunJobOperator", "gcp-cloud-run-job"
)
DbtSeedGcpCloudRunJobOperator = MissingPackage(
"cosmos.operators.gcp_cloud_run_job.DbtSeedGcpCloudRunJobOperator", "gcp-cloud-run-job"
)
DbtSnapshotGcpCloudRunJobOperator = MissingPackage(
"cosmos.operators.gcp_cloud_run_job.DbtSnapshotGcpCloudRunJobOperator", "gcp-cloud-run-job"
)
DbtTestGcpCloudRunJobOperator = MissingPackage(
"cosmos.operators.gcp_cloud_run_job.DbtTestGcpCloudRunJobOperator", "gcp-cloud-run-job"
)


__all__ = [
"ProjectConfig",
"ProfileConfig",
@@ -221,6 +255,14 @@
"DbtSeedAwsEksOperator",
"DbtSnapshotAwsEksOperator",
"DbtTestAwsEksOperator",
# GCP Cloud Run Job Execution Mode
"DbtBuildGcpCloudRunJobOperator",
"DbtLSGcpCloudRunJobOperator",
"DbtRunGcpCloudRunJobOperator",
"DbtRunOperationGcpCloudRunJobOperator",
"DbtSeedGcpCloudRunJobOperator",
"DbtSnapshotGcpCloudRunJobOperator",
"DbtTestGcpCloudRunJobOperator",
]

"""
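The new imports follow the optional-dependency pattern already used for the other execution modes: import eagerly, and on failure bind each name to a MissingPackage placeholder so the error surfaces only when the operator is actually used. A minimal sketch of the idea, assuming a simplified placeholder (the real helper in Cosmos may differ in detail):

# A simplified stand-in for Cosmos's MissingPackage helper (illustrative only).
class MissingPackage:
    def __init__(self, module_name: str, extra: str) -> None:
        self.module_name = module_name
        self.extra = extra

    def __call__(self, *args: object, **kwargs: object) -> None:
        # Fail lazily, at instantiation time, with an actionable message.
        raise RuntimeError(
            f"Could not import {self.module_name}. Install the extra with "
            f"`pip install astronomer-cosmos[{self.extra}]`."
        )

try:
    from cosmos.operators.gcp_cloud_run_job import DbtRunGcpCloudRunJobOperator
except (ImportError, AttributeError):
    DbtRunGcpCloudRunJobOperator = MissingPackage(
        "cosmos.operators.gcp_cloud_run_job.DbtRunGcpCloudRunJobOperator", "gcp-cloud-run-job"
    )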
1 change: 1 addition & 0 deletions cosmos/constants.py
@@ -91,6 +91,7 @@ class ExecutionMode(Enum):
AWS_EKS = "aws_eks"
VIRTUALENV = "virtualenv"
AZURE_CONTAINER_INSTANCE = "azure_container_instance"
GCP_CLOUD_RUN_JOB = "gcp_cloud_run_job"


class InvocationMode(Enum):
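The new enum member is what DAG authors select to route execution through Cloud Run. A minimal sketch, assuming the standard Cosmos entry points; the remaining operator arguments (project_id, region, job_name) would be supplied via operator_args:

from cosmos import ExecutionConfig
from cosmos.constants import ExecutionMode

# Route dbt execution through the Cloud Run Job operators added by this PR.
execution_config = ExecutionConfig(execution_mode=ExecutionMode.GCP_CLOUD_RUN_JOB)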
7 changes: 3 additions & 4 deletions cosmos/converter.py
@@ -116,10 +116,9 @@ def validate_initial_user_config(
:param render_config: Configuration related to how to convert the dbt workflow into an Airflow DAG
:param operator_args: Arguments to pass to the underlying operators.
"""
- if profile_config is None and execution_config.execution_mode not in (
-     ExecutionMode.KUBERNETES,
-     ExecutionMode.AWS_EKS,
-     ExecutionMode.DOCKER,
+ if profile_config is None and execution_config.execution_mode in (
+     ExecutionMode.LOCAL,
+     ExecutionMode.VIRTUALENV,
):
raise CosmosValueError(f"The profile_config is mandatory when using {execution_config.execution_mode}")

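The converter change flips a deny-list into an allow-list: rather than enumerating every remote mode that may omit profile_config, it names the only two modes that run dbt directly on the Airflow worker and therefore need one. New container-based modes such as GCP_CLOUD_RUN_JOB are then exempt without further edits. A sketch of the resulting rule, with a hypothetical helper name:

from cosmos.constants import ExecutionMode

# Hypothetical helper illustrating the new allow-list semantics.
PROFILE_REQUIRED_MODES = (ExecutionMode.LOCAL, ExecutionMode.VIRTUALENV)

def profile_config_required(mode: ExecutionMode) -> bool:
    return mode in PROFILE_REQUIRED_MODES

assert profile_config_required(ExecutionMode.LOCAL)
assert not profile_config_required(ExecutionMode.GCP_CLOUD_RUN_JOB)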
172 changes: 172 additions & 0 deletions cosmos/operators/gcp_cloud_run_job.py
@@ -0,0 +1,172 @@
from __future__ import annotations

import inspect
from typing import Any, Callable, Sequence

from airflow.utils.context import Context

from cosmos.config import ProfileConfig
from cosmos.log import get_logger
from cosmos.operators.base import (
AbstractDbtBaseOperator,
DbtBuildMixin,
DbtLSMixin,
DbtRunMixin,
DbtRunOperationMixin,
DbtSeedMixin,
DbtSnapshotMixin,
DbtTestMixin,
)

logger = get_logger(__name__)

DEFAULT_ENVIRONMENT_VARIABLES: dict[str, str] = {}

try:
from airflow.providers.google.cloud.operators.cloud_run import CloudRunExecuteJobOperator

# The overrides parameter needed to pass the dbt command was added in apache-airflow-providers-google==10.13.0
init_signature = inspect.signature(CloudRunExecuteJobOperator.__init__)
if "overrides" not in init_signature.parameters:
raise AttributeError(
"CloudRunExecuteJobOperator does not have `overrides` attribute. Ensure you've installed apache-airflow-providers-google of at least 10.11.0 "
"separately or with `pip install astronomer-cosmos[...,gcp-cloud-run-job]`."
)
except ImportError:
    raise ImportError(
        "Could not import CloudRunExecuteJobOperator. Ensure you've installed the Google Cloud provider "
        "separately or with `pip install astronomer-cosmos[...,gcp-cloud-run-job]`."
    )


class DbtGcpCloudRunJobBaseOperator(AbstractDbtBaseOperator, CloudRunExecuteJobOperator): # type: ignore
"""
Executes a dbt core cli command in a Cloud Run Job instance with dbt installed in it.

"""

template_fields: Sequence[str] = tuple(
list(AbstractDbtBaseOperator.template_fields) + list(CloudRunExecuteJobOperator.template_fields)
)

intercept_flag = False

def __init__(
self,
# arguments required by CloudRunExecuteJobOperator
project_id: str,
region: str,
job_name: str,
#
profile_config: ProfileConfig | None = None,
command: list[str] | None = None,
environment_variables: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
self.profile_config = profile_config
self.command = command
self.environment_variables = environment_variables or DEFAULT_ENVIRONMENT_VARIABLES
super().__init__(project_id=project_id, region=region, job_name=job_name, **kwargs)

def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> Any:
self.build_command(context, cmd_flags)
self.log.info(f"Running command: {self.command}")
result = CloudRunExecuteJobOperator.execute(self, context)
logger.info(result)

def build_command(self, context: Context, cmd_flags: list[str] | None = None) -> None:
# For the first round, we're going to assume that the command is dbt
# This means that we don't have openlineage support, but we will create a ticket
# to add that in the future
self.dbt_executable_path = "dbt"
dbt_cmd, env_vars = self.build_cmd(context=context, cmd_flags=cmd_flags)
self.environment_variables = {**env_vars, **self.environment_variables}
self.command = dbt_cmd
# Override Cloud Run Job default arguments with dbt command
self.overrides = {
"container_overrides": [
{
"args": self.command,
"env": [{"name": key, "value": value} for key, value in self.environment_variables.items()],
}
],
}


class DbtBuildGcpCloudRunJobOperator(DbtBuildMixin, DbtGcpCloudRunJobBaseOperator):
"""
Executes a dbt core build command.
"""

template_fields: Sequence[str] = DbtGcpCloudRunJobBaseOperator.template_fields + DbtBuildMixin.template_fields # type: ignore[operator]

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)


class DbtLSGcpCloudRunJobOperator(DbtLSMixin, DbtGcpCloudRunJobBaseOperator):
"""
Executes a dbt core ls command.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)


class DbtSeedGcpCloudRunJobOperator(DbtSeedMixin, DbtGcpCloudRunJobBaseOperator):
"""
Executes a dbt core seed command.

:param full_refresh: dbt optional arg - dbt will treat incremental models as table models
"""

template_fields: Sequence[str] = DbtGcpCloudRunJobBaseOperator.template_fields + DbtSeedMixin.template_fields # type: ignore[operator]

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)


class DbtSnapshotGcpCloudRunJobOperator(DbtSnapshotMixin, DbtGcpCloudRunJobBaseOperator):
"""
Executes a dbt core snapshot command.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)


class DbtRunGcpCloudRunJobOperator(DbtRunMixin, DbtGcpCloudRunJobBaseOperator):
"""
Executes a dbt core run command.
"""

template_fields: Sequence[str] = DbtGcpCloudRunJobBaseOperator.template_fields + DbtRunMixin.template_fields # type: ignore[operator]

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)


class DbtTestGcpCloudRunJobOperator(DbtTestMixin, DbtGcpCloudRunJobBaseOperator):
"""
Executes a dbt core test command.
"""

def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwargs: str) -> None:
super().__init__(**kwargs)
# as of now, on_warning_callback in the Cloud Run Job operator does nothing
self.on_warning_callback = on_warning_callback


class DbtRunOperationGcpCloudRunJobOperator(DbtRunOperationMixin, DbtGcpCloudRunJobBaseOperator):
"""
Executes a dbt core run-operation command.

:param macro_name: name of macro to execute
:param args: Supply arguments to the macro. This dictionary will be mapped to the keyword arguments defined in the
selected macro.
"""

template_fields: Sequence[str] = DbtGcpCloudRunJobBaseOperator.template_fields + DbtRunOperationMixin.template_fields # type: ignore[operator]

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
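A usage sketch for the new operators, assuming a pre-created Cloud Run Job whose container image has dbt and the dbt project baked in; the project, region, job, and path values below are placeholders, not part of this PR:

from datetime import datetime

from airflow import DAG

from cosmos import DbtRunGcpCloudRunJobOperator

with DAG(dag_id="dbt_cloud_run_example", start_date=datetime(2024, 1, 1), schedule=None) as dag:
    dbt_run = DbtRunGcpCloudRunJobOperator(
        task_id="dbt_run",
        project_id="my-gcp-project",          # placeholder GCP project
        region="us-central1",                 # placeholder region
        job_name="dbt-cloud-run-job",         # existing Cloud Run Job with dbt installed
        project_dir="/usr/app/dbt_project",   # dbt project path inside the job's image
        environment_variables={"EXAMPLE_VAR": "value"},
    )

At execute time, build_command assembles the dbt CLI call and injects it through the operator's overrides parameter, as shown in the base class above.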
3 changes: 2 additions & 1 deletion cosmos/operators/local.py
@@ -3,6 +3,7 @@
import json
import os
import tempfile
import urllib.parse
import warnings
from abc import ABC, abstractmethod
from functools import cached_property
@@ -449,7 +450,7 @@ def get_datasets(self, source: Literal["inputs", "outputs"]) -> list[Dataset]:
uris = []
for completed in self.openlineage_events_completes:
for output in getattr(completed, source):
dataset_uri = output.namespace + "/" + output.name
dataset_uri = output.namespace + "/" + urllib.parse.quote(output.name)
uris.append(dataset_uri)
self.log.debug("URIs to be converted to Dataset: %s", uris)

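The one-line local.py fix percent-encodes the dataset name before it is joined into a URI, since model or schema names can contain characters (such as spaces) that are invalid in Airflow Dataset URIs. A quick illustration:

import urllib.parse

namespace = "postgres://example-host:5432"
name = "analytics.public.my model"  # a name containing a space

print(namespace + "/" + name)                      # before: invalid raw space in the URI
print(namespace + "/" + urllib.parse.quote(name))  # after: .../analytics.public.my%20model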
3 changes: 3 additions & 0 deletions cosmos/profiles/__init__.py
@@ -13,6 +13,7 @@
from .databricks.oauth import DatabricksOauthProfileMapping
from .databricks.token import DatabricksTokenProfileMapping
from .exasol.user_pass import ExasolUserPasswordProfileMapping
from .oracle.user_pass import OracleUserPasswordProfileMapping
from .postgres.user_pass import PostgresUserPasswordProfileMapping
from .redshift.user_pass import RedshiftUserPasswordProfileMapping
from .snowflake.user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping
@@ -34,6 +35,7 @@
GoogleCloudOauthProfileMapping,
DatabricksTokenProfileMapping,
DatabricksOauthProfileMapping,
OracleUserPasswordProfileMapping,
PostgresUserPasswordProfileMapping,
RedshiftUserPasswordProfileMapping,
SnowflakeUserPasswordProfileMapping,
@@ -77,6 +79,7 @@ def get_automatic_profile_mapping(
"DatabricksTokenProfileMapping",
"DatabricksOauthProfileMapping",
"DbtProfileConfigVars",
"OracleUserPasswordProfileMapping",
"PostgresUserPasswordProfileMapping",
"RedshiftUserPasswordProfileMapping",
"SnowflakeUserPasswordProfileMapping",
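Adding OracleUserPasswordProfileMapping to the profile_mappings list is what lets get_automatic_profile_mapping resolve it from an Airflow connection of type oracle. A sketch of the lookup, assuming a hypothetical connection id:

from cosmos.profiles import get_automatic_profile_mapping

# Resolves to OracleUserPasswordProfileMapping when the "oracle_default"
# connection is of type "oracle" and provides the required user/password fields.
mapping = get_automatic_profile_mapping("oracle_default", profile_args={"schema": "analytics"})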
5 changes: 5 additions & 0 deletions cosmos/profiles/oracle/__init__.py
@@ -0,0 +1,5 @@
"""Oracle Airflow connection -> dbt profile mappings"""

from .user_pass import OracleUserPasswordProfileMapping

__all__ = ["OracleUserPasswordProfileMapping"]
89 changes: 89 additions & 0 deletions cosmos/profiles/oracle/user_pass.py
@@ -0,0 +1,89 @@
"""Maps Airflow Oracle connections using user + password authentication to dbt profiles."""

from __future__ import annotations

import re
from typing import Any

from ..base import BaseProfileMapping


class OracleUserPasswordProfileMapping(BaseProfileMapping):
"""
Maps Airflow Oracle connections using user + password authentication to dbt profiles.
https://docs.getdbt.com/reference/warehouse-setups/oracle-setup
https://airflow.apache.org/docs/apache-airflow-providers-oracle/stable/connections/oracle.html
"""

airflow_connection_type: str = "oracle"
dbt_profile_type: str = "oracle"
is_community: bool = True

required_fields = [
"user",
"password",
]
secret_fields = [
"password",
]
airflow_param_mapping = {
"host": "host",
"port": "port",
"service": "extra.service_name",
"user": "login",
"password": "password",
"database": "extra.service_name",
"connection_string": "extra.dsn",
}

@property
def env_vars(self) -> dict[str, str]:
"""Set oracle thick mode."""
env_vars = super().env_vars
if self._get_airflow_conn_field("extra.thick_mode"):
env_vars["ORA_PYTHON_DRIVER_TYPE"] = "thick"

Check warning on line 44 in cosmos/profiles/oracle/user_pass.py

View check run for this annotation

Codecov / codecov/patch

cosmos/profiles/oracle/user_pass.py#L44

Added line #L44 was not covered by tests
return env_vars

@property
def profile(self) -> dict[str, Any | None]:
"""Gets profile. The password is stored in an environment variable."""
profile = {
"protocol": "tcp",
"port": 1521,
**self.mapped_params,
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
}

if "schema" not in profile and "user" in profile:
proxy = re.search(r"\[([^]]+)\]", profile["user"])
if proxy:
    profile["schema"] = proxy.group(1)
else:
    profile["schema"] = profile["user"]

if "schema" in self.profile_args:
profile["schema"] = self.profile_args["schema"]

return self.filter_null(profile)

@property
def mock_profile(self) -> dict[str, Any | None]:
"""Gets mock profile. Defaults port to 1521."""
profile_dict = {
"protocol": "tcp",
"port": 1521,
**super().mock_profile,
}

if "schema" not in profile_dict and "user" in profile_dict:
proxy = re.search(r"\[([^]]+)\]", profile_dict["user"])
if proxy:
    profile_dict["schema"] = proxy.group(1)
else:
    profile_dict["schema"] = profile_dict["user"]

user_defined_schema = self.profile_args.get("schema")
if user_defined_schema:
profile_dict["schema"] = user_defined_schema
return profile_dict

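Two behaviours in the mapping are worth calling out: when no schema is supplied, it falls back to the connection's user, and a proxy-style Oracle user of the form user[schema] has the bracketed schema extracted by the regex. A hedged end-to-end sketch (connection id and values are placeholders):

import re

from cosmos import ProfileConfig
from cosmos.profiles import OracleUserPasswordProfileMapping

# The proxy-user regex pulls "analytics" out of "app_user[analytics]".
assert re.search(r"\[([^]]+)\]", "app_user[analytics]").group(1) == "analytics"

profile_config = ProfileConfig(
    profile_name="oracle_profile",
    target_name="dev",
    profile_mapping=OracleUserPasswordProfileMapping(
        conn_id="oracle_default",              # placeholder Airflow connection id
        profile_args={"schema": "analytics"},  # explicit schema wins over the derived one
    ),
)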