Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow CORS requests to /api/workflow_landings #18963

Merged
merged 2 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ concurrency:
jobs:
test:
name: Test
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
strategy:
fail-fast: false
matrix:
Expand Down
10 changes: 10 additions & 0 deletions lib/galaxy/webapps/base/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,13 @@ def include_all_package_routers(app: FastAPI, package_name: str):
router = getattr(module, "router", None)
if router:
app.include_router(router, responses=responses)

# handle CORS preflight requests - synchronize with wsgi behavior.
# this needs to happen last so it doesn't clobber routes with explicit cors handling
# it doesn't affect the CORS middleware since the middleware terminates the request handling before routing
@app.options("/api/{rest_of_path:path}")
async def preflight_handler(request: Request, rest_of_path: str) -> Response:
response = Response()
response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Max-Age"] = "600"
return response
54 changes: 53 additions & 1 deletion lib/galaxy/webapps/galaxy/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import (
Any,
AsyncGenerator,
Callable,
cast,
NamedTuple,
Optional,
Expand Down Expand Up @@ -379,6 +380,18 @@ def get_admin_user(trans: SessionRequestContext = DependsOnTrans):
AdminUserRequired = Depends(get_admin_user)


def cors_preflight(response: Response):
response.headers["Access-Control-Allow-Origin"] = "*"
# Only allow CORS safe-listed headers for now (https://developer.mozilla.org/en-US/docs/Glossary/CORS-safelisted_request_header)
response.headers["Access-Control-Allow-Headers"] = "Accept,Accept-Language,Content-Language,Content-Type,Range"
response.headers["Access-Control-Max-Age"] = "600"
response.status_code = 200
return response


CORSPreflightRequired = Depends(cors_preflight)


class BaseGalaxyAPIController(BaseAPIController):
def __init__(self, app: StructuredApp):
super().__init__(app)
Expand All @@ -401,14 +414,21 @@ class FrameworkRouter(APIRouter):

def wrap_with_alias(self, verb: RestVerb, *args, alias: Optional[str] = None, **kwd):
"""
Wraps FastAPI methods with additional alias keyword and require_admin handling.
Wraps FastAPI methods with additional alias keyword, require_admin and CORS handling.

@router.get("/api/thing", alias="/api/deprecated_thing") will then create
routes for /api/thing and /api/deprecated_thing.
"""
kwd = self._handle_galaxy_kwd(kwd)
include_in_schema = kwd.pop("include_in_schema", True)

allow_cors = kwd.pop("allow_cors", False)
if allow_cors:
assert (
"route_class_override" not in kwd
), "Cannot use allow_cors=True on route and specify `route_class_override`"
kwd["route_class_override"] = APICorsRoute

def decorate_route(route, include_in_schema=include_in_schema):
# Decorator solely exists to allow passing `route_class_override` to add_api_route
def decorated_route(func):
Expand All @@ -419,6 +439,21 @@ def decorated_route(func):
include_in_schema=include_in_schema,
**kwd,
)

if allow_cors:

dependencies = kwd.pop("dependencies", [])
dependencies.append(CORSPreflightRequired)

self.add_api_route(
route,
endpoint=lambda: None,
methods=[RestVerb.options],
include_in_schema=False,
dependencies=dependencies,
**kwd,
)

return func

return decorated_route
Expand Down Expand Up @@ -504,6 +539,23 @@ class Router(FrameworkRouter):
user_dependency = DependsOnUser


class APICorsRoute(APIRoute):
"""
Sends CORS headers
"""

def get_route_handler(self) -> Callable:
original_route_handler = super().get_route_handler()

async def custom_route_handler(request: Request) -> Response:
response: Response = await original_route_handler(request)
response.headers["Access-Control-Allow-Origin"] = request.headers.get("Origin", "*")
response.headers["Access-Control-Max-Age"] = "600"
return response

return custom_route_handler


class APIContentTypeRoute(APIRoute):
"""
Determines endpoint to match using content-type.
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/webapps/galaxy/api/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,7 +1165,7 @@ def show_workflow(
) -> StoredWorkflowDetailed:
return self.service.show_workflow(trans, workflow_id, instance, legacy, version)

@router.post("/api/workflow_landings", public=True)
@router.post("/api/workflow_landings", public=True, allow_cors=True)
def create_landing(
self,
trans: ProvidesUserContext = DependsOnTrans,
Expand Down
9 changes: 0 additions & 9 deletions lib/galaxy/webapps/galaxy/fast_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
)
from fastapi.openapi.constants import REF_TEMPLATE
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import Response

from galaxy.schema.generics import CustomJsonSchema
from galaxy.version import VERSION
Expand Down Expand Up @@ -121,14 +120,6 @@ async def add_x_frame_options(request: Request, call_next):
allow_methods=["*"],
max_age=600,
)
else:
# handle CORS preflight requests - synchronize with wsgi behavior.
@app.options("/api/{rest_of_path:path}")
async def preflight_handler(request: Request, rest_of_path: str) -> Response:
response = Response()
response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Max-Age"] = "600"
return response


def include_legacy_openapi(app, gx_app):
Expand Down
1 change: 1 addition & 0 deletions lib/galaxy_test/base/populators.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ def create_workflow_landing(self, payload: CreateWorkflowLandingRequestPayload)
json = payload.model_dump(mode="json")
create_response = self._post(create_url, json, json=True, anon=True)
api_asserts.assert_status_code_is(create_response, 200)
assert create_response.headers["access-control-allow-origin"]
create_response.raise_for_status()
return WorkflowLandingRequest.model_validate(create_response.json())

Expand Down
42 changes: 33 additions & 9 deletions test/integration/test_web_framework_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@

from galaxy_test.driver import integration_util

ENDPOINT_WITH_CORS = "workflow_landings"
ENDPOINT_WITHOUT_EXPLICIT_CORS = "licenses"
WSGI_ENDPOINT = "tools"


class BaseWebFrameworkTestCase(integration_util.IntegrationTestCase):
def _options(self, headers=None):
url = self._api_url("licenses")
def _options(self, headers=None, endpoint=ENDPOINT_WITH_CORS):
url = self._api_url(endpoint)
options_response = options(url, headers=headers or {})
return options_response

Expand All @@ -18,8 +22,18 @@ def test_options(self):
"Access-Control-Request-Method": "GET",
"origin": "http://192.168.0.101:8083",
}
options_response = self._options(headers)
assert options_response.status_code == 200
options_response = self._options(headers, ENDPOINT_WITHOUT_EXPLICIT_CORS)
options_response.raise_for_status()
assert "access-control-allow-origin" not in options_response.headers

def test_options_wsgi(self):
# Tests legacy handling
headers = {
"Access-Control-Request-Method": "GET",
"origin": "http://192.168.0.101:8083",
}
options_response = self._options(headers, WSGI_ENDPOINT)
options_response.raise_for_status()
assert "access-control-allow-origin" not in options_response.headers

def test_origin_not_allowed_default(self):
Expand All @@ -28,10 +42,20 @@ def test_origin_not_allowed_default(self):
"Access-Control-Request-Headers": "Authorization",
"origin": "http://192.168.0.101:8083",
}
options_response = self._options(headers)
assert options_response.status_code == 200
options_response = self._options(headers, ENDPOINT_WITHOUT_EXPLICIT_CORS)
options_response.raise_for_status()
assert "access-control-allow-origin" not in options_response.headers

def test_origin_explicitly_allowed(self):
headers = {
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "Authorization",
"Origin": "http://192.168.0.101:8083",
}
options_response = self._options(headers, ENDPOINT_WITH_CORS)
options_response.raise_for_status()
assert options_response.headers["access-control-allow-origin"] == "http://192.168.0.101:8083"


class TestAllowOriginIntegration(BaseWebFrameworkTestCase):
@classmethod
Expand All @@ -45,7 +69,7 @@ def test_origin_allowed_if_configured(self):
"origin": "http://192.168.0.101:8083",
"Access-Control-Request-Headers": "Authorization",
}
options_response = self._options(headers)
options_response = self._options(headers, ENDPOINT_WITHOUT_EXPLICIT_CORS)
options_response.raise_for_status()
assert "access-control-allow-origin" in options_response.headers
assert options_response.headers["access-control-allow-origin"] == "http://192.168.0.101:8083"
Expand All @@ -57,7 +81,7 @@ def test_origin_allowed_if_configured_via_regex(self):
"origin": "http://rna.galaxyproject.org",
"Access-Control-Request-Headers": "Authorization",
}
options_response = self._options(headers)
options_response = self._options(headers, ENDPOINT_WITHOUT_EXPLICIT_CORS)
options_response.raise_for_status()
assert "access-control-allow-origin" in options_response.headers
assert options_response.headers["access-control-allow-origin"] == "http://rna.galaxyproject.org"
Expand All @@ -69,5 +93,5 @@ def test_origin_not_allowed_if_not_in_configured_list(self):
"origin": "http://192.168.0.102:8083", # swapped ip by one
"Access-Control-Request-Headers": "Authorization",
}
options_response = self._options(headers)
options_response = self._options(headers, ENDPOINT_WITHOUT_EXPLICIT_CORS)
assert options_response.status_code == 400
Loading