Skip to content

Commit

Permalink
feat(oauth2): add support for trino (apache#30081)
Browse files Browse the repository at this point in the history
  • Loading branch information
joaoferrao authored Nov 4, 2024
1 parent 64f8140 commit 305b6df
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 52 deletions.
46 changes: 23 additions & 23 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
oauth2_scope = ""
oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name
oauth2_token_request_uri: str | None = None
oauth2_token_request_type = "data"

# Driver-specific exception that should be mapped to OAuth2RedirectError
oauth2_exception = OAuth2RedirectError
Expand Down Expand Up @@ -525,6 +526,9 @@ def get_oauth2_config(cls) -> OAuth2ClientConfig | None:
"token_request_uri",
cls.oauth2_token_request_uri,
),
"request_content_type": db_engine_spec_config.get(
"request_content_type", cls.oauth2_token_request_type
),
}

return config
Expand Down Expand Up @@ -562,18 +566,16 @@ def get_oauth2_token(
"""
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
response = requests.post(
uri,
json={
"code": code,
"client_id": config["id"],
"client_secret": config["secret"],
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
},
timeout=timeout,
)
return response.json()
req_body = {
"code": code,
"client_id": config["id"],
"client_secret": config["secret"],
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
}
if config["request_content_type"] == "data":
return requests.post(uri, data=req_body, timeout=timeout).json()
return requests.post(uri, json=req_body, timeout=timeout).json()

@classmethod
def get_oauth2_fresh_token(
Expand All @@ -586,17 +588,15 @@ def get_oauth2_fresh_token(
"""
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
response = requests.post(
uri,
json={
"client_id": config["id"],
"client_secret": config["secret"],
"refresh_token": refresh_token,
"grant_type": "refresh_token",
},
timeout=timeout,
)
return response.json()
req_body = {
"client_id": config["id"],
"client_secret": config["secret"],
"refresh_token": refresh_token,
"grant_type": "refresh_token",
}
if config["request_content_type"] == "data":
return requests.post(uri, data=req_body, timeout=timeout).json()
return requests.post(uri, json=req_body, timeout=timeout).json()

@classmethod
def get_allows_alias_in_select(
Expand Down
27 changes: 26 additions & 1 deletion superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
import numpy as np
import pandas as pd
import pyarrow as pa
from flask import ctx, current_app, Flask, g
import requests
from flask import copy_current_request_context, ctx, current_app, Flask, g
from sqlalchemy import text
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import NoSuchTableError
from trino.exceptions import HttpError

from superset import db
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
Expand Down Expand Up @@ -60,11 +62,28 @@
logger = logging.getLogger(__name__)


class CustomTrinoAuthErrorMeta(type):
def __instancecheck__(cls, instance: object) -> bool:
logger.info("is this being called?")
return isinstance(
instance, HttpError
) and "error 401: b'Invalid credentials'" in str(instance)


class TrinoAuthError(HttpError, metaclass=CustomTrinoAuthErrorMeta):
pass


class TrinoEngineSpec(PrestoBaseEngineSpec):
engine = "trino"
engine_name = "Trino"
allows_alias_to_source_column = False

# OAuth 2.0
supports_oauth2 = True
oauth2_exception = TrinoAuthError
oauth2_token_request_type = "data"

@classmethod
def get_extra_table_metadata(
cls,
Expand Down Expand Up @@ -142,6 +161,10 @@ def update_impersonation_config( # pylint: disable=too-many-arguments
# Set principal_username=$effective_username
if backend_name == "trino" and username is not None:
connect_args["user"] = username
if access_token is not None:
http_session = requests.Session()
http_session.headers.update({"Authorization": f"Bearer {access_token}"})
connect_args["http_session"] = http_session

@classmethod
def get_url_for_impersonation(
Expand All @@ -154,6 +177,7 @@ def get_url_for_impersonation(
"""
Return a modified URL with the username set.
:param access_token: Personal access token for OAuth2
:param url: SQLAlchemy URL object
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
Expand Down Expand Up @@ -228,6 +252,7 @@ def execute_with_cursor(
execute_result: dict[str, Any] = {}
execute_event = threading.Event()

@copy_current_request_context
def _execute(
results: dict[str, Any],
event: threading.Event,
Expand Down
4 changes: 4 additions & 0 deletions superset/superset_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ class OAuth2ClientConfig(TypedDict):
# expired access token.
token_request_uri: str

# Not all identity providers expect json. Keycloak expects a form encoded request,
# which in the `requests` package context means using the `data` param, not `json`.
request_content_type: str


class OAuth2TokenResponse(TypedDict, total=False):
"""
Expand Down
7 changes: 6 additions & 1 deletion superset/utils/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import backoff
import jwt
from flask import current_app, url_for
from marshmallow import EXCLUDE, fields, post_load, Schema
from marshmallow import EXCLUDE, fields, post_load, Schema, validate

from superset import db
from superset.distributed_lock import KeyValueDistributedLock
Expand Down Expand Up @@ -192,3 +192,8 @@ class OAuth2ClientConfigSchema(Schema):
)
authorization_request_uri = fields.String(required=True)
token_request_uri = fields.String(required=True)
request_content_type = fields.String(
required=False,
load_default=lambda: "json",
validate=validate.OneOf(["json", "data"]),
)
1 change: 1 addition & 0 deletions tests/unit_tests/db_engine_specs/test_gsheets.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ def oauth2_config() -> OAuth2ClientConfig:
"redirect_uri": "http://localhost:8088/api/v1/oauth2/",
"authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth",
"token_request_uri": "https://oauth2.googleapis.com/token",
"request_content_type": "json",
}


Expand Down
117 changes: 90 additions & 27 deletions tests/unit_tests/db_engine_specs/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@
SupersetDBAPIProgrammingError,
)
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType, SQLAColumnType, SQLType
from superset.superset_typing import (
OAuth2ClientConfig,
ResultSetColumnType,
SQLAColumnType,
SQLType,
)
from superset.utils import json
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
Expand Down Expand Up @@ -421,21 +426,23 @@ def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
def _mock_execute(*args, **kwargs):
mock_cursor.query_id = query_id

mock_cursor.execute.side_effect = _mock_execute
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)
with app.test_request_context("/some/place/"):
mock_cursor.execute.side_effect = _mock_execute

mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
)
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)

mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
)


def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
Expand All @@ -446,23 +453,25 @@ def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
mock_cursor.query_id = None

mock_query = mocker.MagicMock()
g.some_value = "some_value"

def _mock_execute(*args, **kwargs):
assert has_app_context()
assert g.some_value == "some_value"

with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)
with app.test_request_context("/some/place/"):
g.some_value = "some_value"

with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)


def test_get_columns(mocker: MockerFixture):
Expand Down Expand Up @@ -784,3 +793,57 @@ def test_where_latest_partition(
)
== f"""SELECT * FROM table \nWHERE partition_key = {expected_value}"""
)


@pytest.fixture
def oauth2_config() -> OAuth2ClientConfig:
"""
Config for Trino OAuth2.
"""
return {
"id": "trino",
"secret": "very-secret",
"scope": "",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "https://trino.auth.server.example/realms/master/protocol/openid-connect/auth",
"token_request_uri": "https://trino.auth.server.example/master/protocol/openid-connect/token",
"request_content_type": "data",
}


def test_get_oauth2_token(
mocker: MockerFixture,
oauth2_config: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_token`.
"""
from superset.db_engine_specs.trino import TrinoEngineSpec

requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}

assert TrinoEngineSpec.get_oauth2_token(oauth2_config, "code") == {
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
requests.post.assert_called_with(
"https://trino.auth.server.example/master/protocol/openid-connect/token",
data={
"code": "code",
"client_id": "trino",
"client_secret": "very-secret",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"grant_type": "authorization_code",
},
timeout=30.0,
)
1 change: 1 addition & 0 deletions tests/unit_tests/models/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ def test_get_oauth2_config(app_context: None) -> None:
"token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
"scope": "refresh_token session:role:USERADMIN",
"redirect_uri": "http://example.com/api/v1/database/oauth2/",
"request_content_type": "json",
}


Expand Down

0 comments on commit 305b6df

Please sign in to comment.