From bdadb93e185df3084f0ed4eaab75491c4494b387 Mon Sep 17 00:00:00 2001 From: Uwe Winter Date: Mon, 3 Nov 2025 20:05:57 +1100 Subject: [PATCH] feature: oauth login --- README.md | 56 +- mcp-server-galaxy-py/Dockerfile | 17 +- mcp-server-galaxy-py/README.md | 64 +- mcp-server-galaxy-py/USAGE_EXAMPLES.md | 14 +- mcp-server-galaxy-py/pyproject.toml | 2 + .../src/galaxy_mcp/__main__.py | 37 +- mcp-server-galaxy-py/src/galaxy_mcp/auth.py | 892 ++++++++++++++++++ mcp-server-galaxy-py/src/galaxy_mcp/server.py | 388 ++++++-- mcp-server-galaxy-py/tests/README.md | 1 + mcp-server-galaxy-py/tests/TEST_STRATEGY.md | 1 + mcp-server-galaxy-py/tests/test_connection.py | 36 +- mcp-server-galaxy-py/tests/test_oauth.py | 100 ++ mcp-server-galaxy-py/uv.lock | 8 +- 13 files changed, 1488 insertions(+), 128 deletions(-) create mode 100644 mcp-server-galaxy-py/src/galaxy_mcp/auth.py create mode 100644 mcp-server-galaxy-py/tests/test_oauth.py diff --git a/README.md b/README.md index 966cb6e..531276c 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ Note: There is also a work-in-progress TypeScript implementation available in a ## Key Features - **Galaxy Connection**: Connect to any Galaxy instance with a URL and API key +- **OAuth Login (optional)**: Offer browser-based sign-in that exchanges credentials for temporary Galaxy API keys - **Server Information**: Retrieve comprehensive server details including version, configuration, and capabilities - **Tools Management**: Search, view details, and execute Galaxy tools - **Workflow Integration**: Access and import workflows from the Interactive Workflow Composer (IWC) @@ -20,36 +21,65 @@ Note: There is also a work-in-progress TypeScript implementation available in a ## Quick Start -The fastest way to get started is using `uvx`: +The `galaxy-mcp` CLI ships with both stdio (local) and HTTP transports. Choose the setup that +matches your client: ```bash -# Run the server directly without installation +# Stdio transport (default) – great for local development tools uvx galaxy-mcp -# Run with MCP developer tools for interactive exploration -uvx --from galaxy-mcp mcp dev galaxy_mcp.server - -# Run as a deployed MCP server -uvx --from galaxy-mcp mcp run galaxy_mcp.server +# HTTP transport with OAuth (for remote/browser clients) +export GALAXY_URL="https://usegalaxy.org.au/" # Target Galaxy instance +export GALAXY_MCP_PUBLIC_URL="https://mcp.example.com" # Public base URL for OAuth redirects +export GALAXY_MCP_SESSION_SECRET="$(openssl rand -hex 32)" +uvx galaxy-mcp --transport streamable-http --host 0.0.0.0 --port 8000 ``` -You'll need to set up your Galaxy credentials via environment variables: +When running over stdio you can provide long-lived credentials via environment variables: ```bash -export GALAXY_URL= -export GALAXY_API_KEY= +export GALAXY_URL="https://usegalaxy.org/" +export GALAXY_API_KEY="your-api-key" ``` +For OAuth flows the server exchanges user credentials for short-lived Galaxy API keys on demand, so +you typically leave `GALAXY_API_KEY` unset. + ### Alternative Installation ```bash # Install from PyPI pip install galaxy-mcp -# Or from source +# Run (stdio by default) +galaxy-mcp + +# Or from source using uv cd mcp-server-galaxy-py -pip install -r requirements.txt -mcp run main.py +uv sync +uv run galaxy-mcp --transport streamable-http --host 0.0.0.0 --port 8000 +``` + +## Container Usage + +The published image defaults to stdio transport (no HTTP listener): + +```bash +docker run --rm -it \ + -e GALAXY_URL="https://usegalaxy.org/" \ + -e GALAXY_API_KEY="your-api-key" \ + galaxyproject/galaxy-mcp +``` + +For OAuth + HTTP: + +```bash +docker run --rm -it -p 8000:8000 \ + -e GALAXY_URL="https://usegalaxy.org.au/" \ + -e GALAXY_MCP_TRANSPORT="streamable-http" \ + -e GALAXY_MCP_PUBLIC_URL="https://mcp.example.com" \ + -e GALAXY_MCP_SESSION_SECRET="$(openssl rand -hex 32)" \ + galaxyproject/galaxy-mcp ``` ## Development Guidelines diff --git a/mcp-server-galaxy-py/Dockerfile b/mcp-server-galaxy-py/Dockerfile index 05f4df0..ee9db35 100644 --- a/mcp-server-galaxy-py/Dockerfile +++ b/mcp-server-galaxy-py/Dockerfile @@ -41,9 +41,12 @@ RUN groupadd -r app && useradd -r -g app app WORKDIR /app -# Define default environment variables -ENV GALAXY_INSTANCE="https://usegalaxy.org" \ - GALAXY_API_KEY="" +# Define default environment variables (stdio transport by default) +ENV GALAXY_URL="https://usegalaxy.org" \ + GALAXY_API_KEY="" \ + GALAXY_MCP_TRANSPORT="stdio" \ + GALAXY_MCP_HOST="0.0.0.0" \ + GALAXY_MCP_PORT="8000" COPY --from=uv /root/.local /root/.local COPY --from=uv --chown=app:app /app/.venv /app/.venv @@ -63,12 +66,12 @@ USER app # Expose service port EXPOSE 8000 -# Add healthcheck +# Add healthcheck that verifies the HTTP listener is accepting connections HEALTHCHECK --interval=30s --timeout=5s --start-period=5s --retries=3 \ - CMD mcp-server-galaxy --health-check || exit 1 + CMD python -c "import os, socket, sys; addr=('127.0.0.1', int(os.environ.get('GALAXY_MCP_PORT', '8000'))); sock=socket.socket(); sock.settimeout(2); sock.connect(addr); sock.close()" # Use tini as init to handle signals properly ENTRYPOINT ["/usr/bin/tini", "--"] -# Update command to use environment variables instead of DB path -CMD ["mcp-server-galaxy", "--galaxy-url", "${GALAXY_INSTANCE}", "--api-key", "${GALAXY_API_KEY}"] +# Start the MCP server (transport/host/port taken from environment variables) +CMD ["galaxy-mcp"] diff --git a/mcp-server-galaxy-py/README.md b/mcp-server-galaxy-py/README.md index 71b3993..6a42248 100644 --- a/mcp-server-galaxy-py/README.md +++ b/mcp-server-galaxy-py/README.md @@ -5,6 +5,7 @@ This is the Python implementation of the Galaxy MCP server, providing a Model Co ## Features - Complete Galaxy API integration through BioBlend +- Optional OAuth login flow for HTTP deployments - Interactive Workflow Composer (IWC) integration - FastMCP2 server with remote deployment support - Type-annotated Python codebase @@ -39,54 +40,65 @@ uv sync --all-extras ## Configuration -The server requires Galaxy credentials to connect to an instance. You can provide these via environment variables: +At minimum the server needs to know which Galaxy instance to target: ```bash -export GALAXY_URL= -export GALAXY_API_KEY= +export GALAXY_URL="https://usegalaxy.org.au/" ``` -Alternatively, create a `.env` file in the project root with these variables. +How you authenticate depends on your transport: -## Usage +- **Stdio / long-lived sessions** – provide an API key: + + ```bash + export GALAXY_API_KEY="your-api-key" + ``` + +- **HTTP / OAuth** – configure the public URL that users reach and a signing secret for session + tokens. The server mints short-lived Galaxy API keys on behalf of each user. -### Quick Start with uvx + ```bash + export GALAXY_MCP_PUBLIC_URL="https://mcp.example.com" + export GALAXY_MCP_SESSION_SECRET="$(openssl rand -hex 32)" + ``` -The fastest way to run the Galaxy MCP server is using `uvx`: + Optionally set `GALAXY_MCP_CLIENT_REGISTRY` to control where OAuth client registrations are stored. + +You can also steer the transport with `GALAXY_MCP_TRANSPORT` (`stdio`, `streamable-http`, or `sse`). +All variables can be placed in a `.env` file for convenience. + +## Usage + +### Quick Start with `uvx` ```bash -# Run the server directly without installation +# Local stdio transport (no network listener) uvx galaxy-mcp -# Run with FastMCP2 dev tools -uvx --from galaxy-mcp fastmcp dev src/galaxy_mcp/server.py - -# Run as remote server -uvx --from galaxy-mcp fastmcp run src/galaxy_mcp/server.py --transport sse --port 8000 +# Remote/browser clients with HTTP + OAuth +export GALAXY_URL="https://usegalaxy.org.au/" +export GALAXY_MCP_PUBLIC_URL="https://mcp.example.com" +export GALAXY_MCP_SESSION_SECRET="$(openssl rand -hex 32)" +uvx galaxy-mcp --transport streamable-http --host 0.0.0.0 --port 8000 ``` -### As a standalone MCP server +### Installed CLI ```bash -# Install and run the MCP server pip install galaxy-mcp -galaxy-mcp - -# The server will wait for MCP protocol messages on stdin +galaxy-mcp --transport streamable-http --host 0.0.0.0 --port 8000 ``` -### With MCP clients +If `--transport` is omitted the server defaults to stdio and reads/writes MCP messages via stdin/stdout. -```bash -# Use with FastMCP2 CLI tools -fastmcp dev src/galaxy_mcp/server.py -fastmcp run src/galaxy_mcp/server.py +### Working from a checkout -# Use with other MCP-compatible clients -your-mcp-client galaxy-mcp +```bash +uv sync +uv run galaxy-mcp --transport streamable-http --host 0.0.0.0 --port 8000 ``` -See [USAGE_EXAMPLES.md](USAGE_EXAMPLES.md) for detailed usage patterns and common examples. +See [USAGE_EXAMPLES.md](USAGE_EXAMPLES.md) for detailed tool usage patterns. ## Available MCP Tools diff --git a/mcp-server-galaxy-py/USAGE_EXAMPLES.md b/mcp-server-galaxy-py/USAGE_EXAMPLES.md index 8048478..797affe 100644 --- a/mcp-server-galaxy-py/USAGE_EXAMPLES.md +++ b/mcp-server-galaxy-py/USAGE_EXAMPLES.md @@ -6,17 +6,25 @@ This document provides common usage patterns and examples for the Galaxy MCP ser ### 1. Connect to Galaxy -First, you need to establish a connection to your Galaxy instance: +For stdio deployments you can authenticate with a long-lived API key: ```python -# Option 1: Use environment variables (recommended) # Set GALAXY_URL and GALAXY_API_KEY in your environment or .env file connect() -# Option 2: Provide credentials directly +# Or provide credentials directly connect(url="https://your-galaxy-instance.org", api_key="your-api-key") ``` +For HTTP deployments that use OAuth the active session is resolved automatically. Calling +`connect()` without arguments simply confirms the session and returns user details: + +```python +session_info = connect() +assert session_info["auth"] == "oauth" +print(session_info["user"]["username"]) +``` + #### Get server information Once connected, you can retrieve comprehensive information about the Galaxy server: diff --git a/mcp-server-galaxy-py/pyproject.toml b/mcp-server-galaxy-py/pyproject.toml index 5fab9e7..c42edfd 100644 --- a/mcp-server-galaxy-py/pyproject.toml +++ b/mcp-server-galaxy-py/pyproject.toml @@ -26,7 +26,9 @@ classifiers = [ ] requires-python = ">=3.10" dependencies = [ + "anyio>=4.0.0", "bioblend>=1.5.0", + "cryptography>=41.0.0", "fastmcp>=2.3.0", "requests>=2.32.3", "python-dotenv>=1.0.0", diff --git a/mcp-server-galaxy-py/src/galaxy_mcp/__main__.py b/mcp-server-galaxy-py/src/galaxy_mcp/__main__.py index acc10fc..d6d60f9 100644 --- a/mcp-server-galaxy-py/src/galaxy_mcp/__main__.py +++ b/mcp-server-galaxy-py/src/galaxy_mcp/__main__.py @@ -1,12 +1,41 @@ """Command-line entry point for Galaxy MCP server.""" +import argparse +import os + from . import server -def run(): - """Run the MCP server.""" - # Use FastMCP's simplified run method - server.mcp.run() +def run() -> None: + """Run the MCP server using stdio or HTTP transport.""" + parser = argparse.ArgumentParser(description="Run the Galaxy MCP server.") + parser.add_argument( + "--transport", + choices=["stdio", "streamable-http", "sse"], + help="Transport to use (defaults to environment or stdio).", + ) + parser.add_argument("--host", help="HTTP host to bind when using HTTP transports.") + parser.add_argument( + "--port", + type=int, + help="HTTP port to bind when using HTTP transports.", + ) + parser.add_argument( + "--path", + help="Optional HTTP path when using streamable transports.", + ) + args = parser.parse_args() + + selected = (args.transport or os.environ.get("GALAXY_MCP_TRANSPORT") or "stdio").lower() + if selected in {"streamable-http", "sse"}: + server.run_http_server( + host=args.host, + port=args.port, + transport=selected, + path=args.path, + ) + else: + server.mcp.run() if __name__ == "__main__": diff --git a/mcp-server-galaxy-py/src/galaxy_mcp/auth.py b/mcp-server-galaxy-py/src/galaxy_mcp/auth.py new file mode 100644 index 0000000..87fc487 --- /dev/null +++ b/mcp-server-galaxy-py/src/galaxy_mcp/auth.py @@ -0,0 +1,892 @@ +"""Authentication helpers and stateless OAuth provider for the Galaxy MCP server.""" + +from __future__ import annotations + +import base64 +import hashlib +import inspect +import json +import logging +import secrets +import textwrap +import time +from collections.abc import Callable +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +import anyio +import requests +from bioblend.galaxy import GalaxyInstance +from cryptography.fernet import Fernet, InvalidToken +from fastmcp.server.auth.auth import ( + AccessToken as FastMCPAccessToken, +) +from fastmcp.server.auth.auth import ( + ClientRegistrationOptions, + OAuthProvider, + RevocationOptions, +) +from mcp.server.auth.provider import ( + AccessToken, + AuthorizationCode, + AuthorizationParams, + RefreshToken, + construct_redirect_uri, +) +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken +from starlette.requests import Request +from starlette.responses import ( + HTMLResponse, + JSONResponse, + PlainTextResponse, + RedirectResponse, + Response, +) +from starlette.routing import Route + +try: # pragma: no cover - fallback import for Python < 3.12 + from typing import override +except ImportError: # pragma: no cover - Python < 3.12 without typing.override + try: + from typing_extensions import override # type: ignore + except ImportError: # pragma: no cover - hard fallback if typing_extensions missing + + def override(func): # type: ignore + return func + + +logger = logging.getLogger(__name__) + +AUTH_CODE_TTL_SECONDS = 5 * 60 +ACCESS_TOKEN_TTL_SECONDS = 60 * 60 +REFRESH_TOKEN_TTL_SECONDS = 7 * 24 * 60 * 60 + +LOGIN_PATH = "/galaxy-auth/login" +RESOURCE_METADATA_PATH = "/.well-known/oauth-protected-resource" + +CHATGPT_LOGO_DATA_URI = ( + "data:image/svg+xml;base64," + "PHN2ZyB3aWR0aD0iNzIiIGhlaWdodD0iNzIiIHZpZXdCb3g9IjAgMCA3MiA3MiIgZmlsbD0ibm9u" + "ZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHJlY3Qgd2lkdGg9IjcyIiBo" + "ZWlnaHQ9IjcyIiByeD0iMTgiIGZpbGw9IiMxMDEzMUYiLz4KPHBhdGggZD0iTTM2IDE3QzMwLjQ3" + "NyAxNyAyNS42MjEgMjAuNTY3IDI0LjAwMSAyNS42NzZDMjAuOTI3IDI2Ljc4OCAxOC41IDI5Ljg1" + "NSAxOC41IDMzLjVDMTguNSAzNy42NDIgMjEuODU4IDQxIDI2IDQxSDI3LjAzQzI3LjAxMSA0MS4z" + "MzIgMjcgNDEuNjY1IDI3IDQyQzI3IDQ4LjA3NSAzMS45MjUgNTMgMzggNTNDNDMuNTIzIDUzIDQ4" + "LjM3OSA0OS40MzMgNDkuOTk5IDQ0LjMyNEM1My4wNzMgNDMuMjEyIDU1LjUgNDAuMTQ1IDU1LjUg" + "MzYuNUM1NS41IDMyLjM1OCA1Mi4xNDIgMjkgNDggMjlINDYuOTdDNDYuOTg5IDI4LjY2OCA0NyAy" + "OC4zMzUgNDcgMjhDNDcgMjEuOTI1IDQyLjA3NSAxNyAzNiAxN1oiIHN0cm9rZT0iIzczRkZBOSIgc" + "3Ryb2tlLXdpZHRoPSIzLjIiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIgc3Ryb2tlLWxpbmVqb2luPS" + "Jyb3VuZCIvPgo8cGF0aCBkPSJNMjggNDFDMjguNTM3IDQ3LjUzOSAzNC4wMDMgNTMgNDAuNyA1MyI" + "gc3Ryb2tlPSIjNENFMEIzIiBzdHJva2Utd2lkdGg9IjIuNCIgc3Ryb2tlLWxpbmVjYXA9InJvdW5k" + "Ii8+CjxwYXRoIGQ9Ik00NCAyOUM0My40NjMgMjIuNDYxIDM3Ljk5NyAxNyAzMS4zIDE3IiBzdHJva" + "2U9IiM0Q0UwQjMiIHN0cm9rZS13aWR0aD0iMi40IiBzdHJva2UtbGluZWNhcD0icm91bmQiLz4KPC" + "9zdmc+" +) + +GALAXY_LOGO_DATA_URI = ( + "data:image/svg+xml;base64," + "PHN2ZyB3aWR0aD0iNzIiIGhlaWdodD0iNzIiIHZpZXdCb3g9IjAgMCA3MiA3MiIgZmlsbD0ibm9u" + "ZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHJlY3Qgd2lkdGg9IjcyIiBo" + "ZWlnaHQ9IjcyIiByeD0iMTgiIGZpbGw9InVybCgjZykiLz4KPGRlZnM+CjxsaW5lYXJHcmFkaWVu" + "dCBpZD0iZyIgeDE9IjE0IiB5MT0iMTIiIHgyPSI1OCIgeTI9IjYwIiBncmFkaWVudFVuaXRzPSJ1" + "c2VyU3BhY2VPblVzZSI+CjxzdG9wIHN0b3AtY29sb3I9IiMxMjE4M0YiLz4KPHN0b3Agb2Zmc2V0" + "PSIxIiBzdG9wLWNvbG9yPSIjMkY2REZGIi8+CjwvbGluZWFyR3JhZGllbnQ+CjwvZGVmcz4KPHBh" + "dGggZD0iTTM2IDU3QzQ0LjI4NCA1NyA1MSA0MSAzNiA0MUMyMSA0MSAyNy43MTYgNTcgMzYgNTda" + "IiBmaWxsPSIjRkNEMzREIiBmaWxsLW9wYWNpdHk9IjAuODUiLz4KPHBhdGggZD0iTTE0IDMxQzI2" + "LjUgMzEgMzIuNSA0MSAzMS41IDQ3QzMwLjUgNTMgMTkgNTcgMTQgNTciIHN0cm9rZT0iI0Y5RjdG" + "RiIgc3Ryb2tlLXdpZHRoPSIyLjYiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8cGF0aCBkPSJN" + "NTggMzFDNDUuNSAzMSAzOS41IDQxIDQwLjUgNDdDNDEuNSA1MyA1MyA1NyA1OCA1NyIgc3Ryb2tl" + "PSIjQzBEQUZGIiBzdHJva2Utd2lkdGg9IjIuNiIgc3Ryb2tlLWxpbmVjYXA9InJvdW5kIi8+Cjxj" + "aXJjbGUgY3g9IjIxIiBjeT0iMjMiIHI9IjMiIGZpbGw9IiNGQ0QzNEQiLz4KPGNpcmNsZSBjeD0i" + "NTEiIGN5PSIxOSIgcj0iMi41IiBmaWxsPSIjQzBEQUZGIi8+Cjwvc3ZnPg==" +) + + +class GalaxyAuthenticationError(Exception): + """Raised when Galaxy authentication fails.""" + + +@dataclass +class AuthorizationTransaction: + """Stored data for an in-flight authorization request.""" + + client_id: str + redirect_uri: str + redirect_uri_provided_explicitly: bool + state: str | None + code_challenge: str + code_challenge_method: str + scopes: list[str] + created_at: float + + +@dataclass(frozen=True) +class GalaxyCredentials: + """Decoded Galaxy credentials from an access token.""" + + galaxy_url: str + api_key: str + username: str + user_email: str | None + expires_at: int + scopes: list[str] + client_id: str + + +class GalaxyOAuthProvider(OAuthProvider): + """OAuth provider that authenticates users against a Galaxy instance.""" + + def __init__( + self, + *, + base_url: str, + galaxy_url: str, + required_scopes: list[str] | None = None, + session_secret: str | None = None, + client_registry_path: str | Path | None = None, + ): + client_registration = ClientRegistrationOptions(enabled=True) + revocation_options = RevocationOptions(enabled=True) + + normalized_base_url = base_url.rstrip("/") + if not normalized_base_url: + raise ValueError("base_url must be a non-empty string") + + super_init = super().__init__ + super_params = inspect.signature(super_init).parameters + super_kwargs: dict[str, Any] = {} + if "base_url" in super_params: + super_kwargs["base_url"] = normalized_base_url + if "issuer_url" in super_params: + super_kwargs["issuer_url"] = normalized_base_url + if "service_documentation_url" in super_params: + super_kwargs["service_documentation_url"] = None + if "client_registration_options" in super_params: + super_kwargs["client_registration_options"] = client_registration + if "revocation_options" in super_params: + super_kwargs["revocation_options"] = revocation_options + if "required_scopes" in super_params: + super_kwargs["required_scopes"] = required_scopes or ["galaxy:full"] + + super_init(**super_kwargs) + + self.base_url = normalized_base_url + self.required_scopes = required_scopes or ["galaxy:full"] + self._galaxy_url = galaxy_url if galaxy_url.endswith("/") else f"{galaxy_url}/" + self._transactions: dict[str, AuthorizationTransaction] = {} + self._clients: dict[str, OAuthClientInformationFull] = {} + self._fernet = Fernet(self._derive_key(session_secret)) + self._client_registry_path = ( + Path(client_registry_path).expanduser() if client_registry_path else None + ) + + self._load_client_registry() + + # ------------------------------------------------------------------ + # OAuth provider interface + # ------------------------------------------------------------------ + + @override + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + return self._clients.get(client_id) + + @override + async def register_client(self, client_info: OAuthClientInformationFull) -> None: + self._clients[client_info.client_id] = client_info + await self._persist_client_registry() + + @override + async def authorize( + self, client: OAuthClientInformationFull, params: AuthorizationParams + ) -> str: + txn_id = secrets.token_urlsafe(32) + transaction = AuthorizationTransaction( + client_id=client.client_id, + redirect_uri=str(params.redirect_uri), + redirect_uri_provided_explicitly=params.redirect_uri_provided_explicitly, + state=params.state, + code_challenge=params.code_challenge, + code_challenge_method=getattr(params, "code_challenge_method", "S256"), + scopes=params.scopes or self.required_scopes, + created_at=time.time(), + ) + self._transactions[txn_id] = transaction + + base_url = str(self.base_url) + login_url = construct_redirect_uri( + f"{base_url.rstrip('/')}{LOGIN_PATH}", + txn=txn_id, + galaxy=self._galaxy_url.rstrip("/"), + ) + logger.debug("Created authorization transaction %s for client %s", txn_id, client.client_id) + return login_url + + @override + async def load_authorization_code( + self, + client: OAuthClientInformationFull, + authorization_code: str, + ) -> AuthorizationCode | None: + try: + payload = self._decrypt_payload(authorization_code, expected_type="authorization_code") + except InvalidToken: + return None + + if payload["client_id"] != client.client_id: + return None + + if payload["exp"] < time.time(): + return None + + return AuthorizationCode( + code=authorization_code, + client_id=client.client_id, + scopes=payload["scopes"], + expires_at=payload["exp"], + code_challenge=payload["code_challenge"], + redirect_uri=payload["redirect_uri"], + redirect_uri_provided_explicitly=payload["redirect_uri_provided_explicitly"], + resource=None, + ) + + @override + async def exchange_authorization_code( + self, + client: OAuthClientInformationFull, + authorization_code: AuthorizationCode, + ) -> OAuthToken: + payload = self._decrypt_payload(authorization_code.code, expected_type="authorization_code") + if payload["exp"] < time.time(): + raise GalaxyAuthenticationError("Authorization code expired.") + + if payload["client_id"] != client.client_id: + raise GalaxyAuthenticationError("Authorization code issued for a different client.") + + galaxy_payload = payload["galaxy"] + return self._issue_tokens( + client_id=client.client_id, scopes=payload["scopes"], galaxy_payload=galaxy_payload + ) + + @override + async def load_refresh_token( + self, + client: OAuthClientInformationFull, + refresh_token: str, + ) -> RefreshToken | None: + try: + payload = self._decrypt_payload(refresh_token, expected_type="refresh") + except InvalidToken: + return None + + if payload["client_id"] != client.client_id: + return None + if payload["exp"] < time.time(): + return None + + return RefreshToken( + token=refresh_token, + client_id=payload["client_id"], + scopes=payload["scopes"], + expires_at=payload["exp"], + ) + + @override + async def exchange_refresh_token( + self, + client: OAuthClientInformationFull, + refresh_token: RefreshToken, + scopes: list[str], + ) -> OAuthToken: + payload = self._decrypt_payload(refresh_token.token, expected_type="refresh") + if payload["exp"] < time.time(): + raise GalaxyAuthenticationError("Refresh token expired.") + + if payload["client_id"] != client.client_id: + raise GalaxyAuthenticationError("Refresh token issued for a different client.") + + resolved_scopes = scopes or payload["scopes"] + return self._issue_tokens( + client_id=client.client_id, scopes=resolved_scopes, galaxy_payload=payload["galaxy"] + ) + + @override + async def load_access_token(self, token: str) -> AccessToken | None: + try: + payload = self._decrypt_payload(token, expected_type="access") + except InvalidToken: + return None + + if payload["exp"] < time.time(): + return None + + galaxy_info = payload["galaxy"] + return FastMCPAccessToken( + token=token, + client_id=payload["client_id"], + scopes=payload["scopes"], + expires_at=payload["exp"], + claims={ + "galaxy_url": galaxy_info["url"], + "username": galaxy_info["username"], + "user_email": galaxy_info.get("user_email"), + }, + ) + + @override + async def revoke_token(self, token: AccessToken | RefreshToken) -> None: + # Stateless tokens cannot be selectively revoked without external storage. + logger.debug( + "Revocation requested for token, but stateless tokens cannot be revoked individually." + ) + + # ------------------------------------------------------------------ + # Integration helpers + # ------------------------------------------------------------------ + + @staticmethod + def _normalize_base_path(base_path: str | None) -> str | None: + if not base_path: + return None + + normalized = base_path if base_path.startswith("/") else f"/{base_path}" + normalized = normalized.rstrip("/") + if not normalized or normalized == "/": + return None + return normalized + + def get_login_paths(self, base_path: str | None = None) -> set[str]: + login_paths = {LOGIN_PATH} + normalized = self._normalize_base_path(base_path) + if normalized: + login_paths.add(f"{normalized}{LOGIN_PATH}") + return login_paths + + def get_resource_metadata_paths(self, base_path: str | None = None) -> set[str]: + metadata_paths = {RESOURCE_METADATA_PATH} + normalized = self._normalize_base_path(base_path) + if normalized: + metadata_paths.add(f"{normalized}{RESOURCE_METADATA_PATH}") + return metadata_paths + + @override + async def handle_login(self, request: Request) -> Response: + """Public wrapper for the login handler so it can be registered on FastMCP routes.""" + return await self._login_handler(request) + + @override + def get_resource_metadata(self) -> dict[str, Any]: + """Return OAuth protected resource metadata.""" + return { + "resource": self._galaxy_url, + "authorization_servers": [self.base_url], + "scopes_supported": self.required_scopes, + "token_types_supported": ["Bearer"], + } + + @override + async def handle_resource_metadata(self, request: Request) -> Response: + """Return OAuth protected resource metadata.""" + return JSONResponse(self.get_resource_metadata()) + + @override + def get_routes( + self, mcp_path: str | None = None, mcp_endpoint: Any | None = None + ) -> list[Route]: + """ + Return the Starlette routes that expose the OAuth surface. + + FastMCP's base OAuth implementation eagerly registers generic login/resource-metadata + endpoints. Galaxy-specific login handling needs to replace those with custom handlers + (including base-path-prefixed variants), so we strip the parent definitions first and + then install our own. This defensive dedupe also shields us from future FastMCP routing + changes that might otherwise create duplicate routes and confusing behaviour. + """ + routes = super().get_routes(mcp_path, mcp_endpoint) + + base_path = self._normalize_base_path( + urlparse(str(self.base_url)).path if self.base_url else None + ) + login_paths = self.get_login_paths(base_path) + metadata_paths = self.get_resource_metadata_paths(base_path) + + routes = [ + route + for route in routes + if not (isinstance(route, Route) and route.path in login_paths | metadata_paths) + ] + + existing_paths = {route.path for route in routes if isinstance(route, Route)} + + for path in login_paths: + if path not in existing_paths: + routes.append(Route(path, endpoint=self.handle_login, methods=["GET", "POST"])) + existing_paths.add(path) + + for path in metadata_paths: + if path not in existing_paths: + routes.append(Route(path, endpoint=self.handle_resource_metadata, methods=["GET"])) + existing_paths.add(path) + + return routes + + def _load_client_registry(self) -> None: + if not self._client_registry_path: + return + + path = self._client_registry_path + try: + if not path.exists(): + return + + raw = path.read_text(encoding="utf-8") + if not raw.strip(): + return + payload = json.loads(raw) + if not isinstance(payload, list): + logger.warning("Client registry at %s is not a list; ignoring contents.", path) + return + + for entry in payload: + try: + client = OAuthClientInformationFull.model_validate(entry) + except Exception as exc: # pragma: no cover - defensive + logger.warning("Failed to load client entry from registry: %s", exc) + continue + self._clients[client.client_id] = client + except Exception as exc: # pragma: no cover - defensive + logger.warning("Failed to load client registry from %s: %s", path, exc) + + async def _persist_client_registry(self) -> None: + if not self._client_registry_path: + return + + path = self._client_registry_path + clients_data = [ + client.model_dump(mode="json") + for client in sorted(self._clients.values(), key=lambda c: c.client_id) + ] + + def _write() -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_suffix(path.suffix + ".tmp") + with tmp_path.open("w", encoding="utf-8") as fh: + json.dump(clients_data, fh, separators=(",", ":"), sort_keys=True) + tmp_path.replace(path) + + try: + await anyio.to_thread.run_sync(_write, abandon_on_cancel=True) + except Exception as exc: # pragma: no cover - defensive + logger.warning("Failed to persist client registry to %s: %s", path, exc) + + def decode_access_token(self, token: str) -> dict[str, Any] | None: + try: + payload = self._decrypt_payload(token, expected_type="access") + except InvalidToken: + return None + + if payload["exp"] < time.time(): + return None + + return payload + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _derive_key(self, secret: str | None) -> bytes: + if secret: + digest = hashlib.sha256(secret.encode("utf-8")).digest() + return base64.urlsafe_b64encode(digest) + key = Fernet.generate_key() + logger.warning( + "GALAXY_MCP_SESSION_SECRET is not set; generated a volatile secret. " + "All tokens will become invalid on restart." + ) + return key + + def _encrypt_payload(self, payload: dict[str, Any]) -> str: + serialized = json.dumps(payload, separators=(",", ":"), sort_keys=True).encode("utf-8") + return self._fernet.encrypt(serialized).decode("utf-8") + + def _decrypt_payload(self, token: str, *, expected_type: str) -> dict[str, Any]: + data = self._fernet.decrypt(token.encode("utf-8")) + payload: dict[str, Any] = json.loads(data.decode("utf-8")) + if payload.get("typ") != expected_type: + raise InvalidToken("Token type mismatch") + return payload + + def _issue_tokens( + self, *, client_id: str, scopes: list[str], galaxy_payload: dict[str, Any] + ) -> OAuthToken: + now = int(time.time()) + access_payload = { + "typ": "access", + "client_id": client_id, + "scopes": scopes, + "galaxy": galaxy_payload, + "exp": now + ACCESS_TOKEN_TTL_SECONDS, + "iat": now, + "nonce": secrets.token_urlsafe(8), + } + + refresh_payload = { + "typ": "refresh", + "client_id": client_id, + "scopes": scopes, + "galaxy": galaxy_payload, + "exp": now + REFRESH_TOKEN_TTL_SECONDS, + "iat": now, + "nonce": secrets.token_urlsafe(8), + } + + access_token = self._encrypt_payload(access_payload) + refresh_token = self._encrypt_payload(refresh_payload) + + return OAuthToken( + access_token=access_token, + token_type="bearer", + expires_in=ACCESS_TOKEN_TTL_SECONDS, + refresh_token=refresh_token, + scope=" ".join(scopes), + ) + + async def _login_handler(self, request: Request) -> Response: + txn_id = request.query_params.get("txn") + if not txn_id: + return PlainTextResponse("Missing transaction identifier.", status_code=400) + + transaction = self._transactions.get(txn_id) + if not transaction: + return PlainTextResponse("Authorization request is no longer valid.", status_code=400) + + if request.method == "GET": + return self._render_login_form(transaction, error=request.query_params.get("error")) + + form = await request.form() + username = (form.get("username") or "").strip() + password = (form.get("password") or "").strip() + + if not username or not password: + return self._render_login_form(transaction, error="Username and password are required.") + + try: + redirect_url = await self._authenticate_and_complete(txn_id, username, password) + except GalaxyAuthenticationError as exc: + return self._render_login_form(transaction, error=str(exc)) + except Exception as exc: # pragma: no cover - defensive + logger.exception("Unexpected error during Galaxy login: %s", exc) + return self._render_login_form( + transaction, error="Unexpected error during authentication." + ) + + return RedirectResponse(redirect_url, status_code=303) + + async def _authenticate_and_complete(self, txn_id: str, username: str, password: str) -> str: + transaction = self._transactions.pop(txn_id, None) + if not transaction: + raise GalaxyAuthenticationError( + "Authorization request expired. Please restart the flow." + ) + + api_key = await self._get_api_key(username, password) + user_info = await self._get_user_info(api_key) + + galaxy_payload = { + "url": self._galaxy_url, + "api_key": api_key, + "username": user_info.get("username") or username, + "user_email": user_info.get("email"), + } + + code_payload = { + "typ": "authorization_code", + "client_id": transaction.client_id, + "scopes": transaction.scopes, + "code_challenge": transaction.code_challenge, + "code_challenge_method": transaction.code_challenge_method, + "redirect_uri": transaction.redirect_uri, + "redirect_uri_provided_explicitly": transaction.redirect_uri_provided_explicitly, + "galaxy": galaxy_payload, + "exp": time.time() + AUTH_CODE_TTL_SECONDS, + "nonce": secrets.token_urlsafe(8), + } + + code_value = self._encrypt_payload(code_payload) + + logger.info("Galaxy authentication successful for user %s", galaxy_payload["username"]) + return construct_redirect_uri( + transaction.redirect_uri, code=code_value, state=transaction.state + ) + + async def _get_api_key(self, username: str, password: str) -> str: + url = f"{self._galaxy_url}api/authenticate/baseauth" + + def _request_api_key() -> str: + response = requests.get(url, auth=(username, password), timeout=15) + if response.status_code == 401: + raise GalaxyAuthenticationError("Invalid Galaxy credentials.") + response.raise_for_status() + payload = response.json() + key = payload.get("api_key") + if not key: + raise GalaxyAuthenticationError("Galaxy did not return an API key.") + return key + + return await anyio.to_thread.run_sync(_request_api_key) + + async def _get_user_info(self, api_key: str) -> dict[str, Any]: + def _fetch() -> dict[str, Any]: + gi = GalaxyInstance(url=self._galaxy_url, key=api_key) + return gi.users.get_current_user() + + try: + return await anyio.to_thread.run_sync(_fetch) + except Exception as exc: + raise GalaxyAuthenticationError("Failed to validate API key with Galaxy.") from exc + + def _render_login_form( + self, transaction: AuthorizationTransaction, error: str | None = None + ) -> HTMLResponse: + scopes_text = ", ".join(transaction.scopes) if transaction.scopes else "galaxy:full" + error_html = f'' if error else "" + html = textwrap.dedent( + f""" + + + + + Authorize ChatGPT for Galaxy + + + +
+ +

Allow ChatGPT to access Galaxy

+

Sign in to {self._galaxy_url.rstrip('/')}.

+

Scopes: {scopes_text}

+ {error_html} +
+
+ + +
+
+ + +
+ +
+

+ Need help? + + usegalaxy.org.au + +

+
+ + + """ + ) + return HTMLResponse(html) + + +_AUTH_PROVIDER: GalaxyOAuthProvider | None = None + + +def configure_auth_provider(provider: GalaxyOAuthProvider) -> None: + """Register the global auth provider instance.""" + global _AUTH_PROVIDER + _AUTH_PROVIDER = provider + + +def get_auth_provider() -> GalaxyOAuthProvider | None: + """Return the configured auth provider, if any.""" + return _AUTH_PROVIDER + + +def get_active_session( + get_token: Callable[[], AccessToken | None], +) -> tuple[GalaxyCredentials | None, str | None]: + """Decode the access token from the request and extract Galaxy credentials.""" + provider = get_auth_provider() + if not provider: + return None, None + + access_token = get_token() + if access_token is None: + return None, None + + token_payload = provider.decode_access_token(access_token.token) + if not token_payload: + return None, None + + galaxy_payload = token_payload["galaxy"] + credentials = GalaxyCredentials( + galaxy_url=galaxy_payload["url"], + api_key=galaxy_payload["api_key"], + username=galaxy_payload["username"], + user_email=galaxy_payload.get("user_email"), + expires_at=token_payload["exp"], + scopes=token_payload["scopes"], + client_id=token_payload["client_id"], + ) + return credentials, credentials.api_key diff --git a/mcp-server-galaxy-py/src/galaxy_mcp/server.py b/mcp-server-galaxy-py/src/galaxy_mcp/server.py index ac7d743..8be8a9e 100644 --- a/mcp-server-galaxy-py/src/galaxy_mcp/server.py +++ b/mcp-server-galaxy-py/src/galaxy_mcp/server.py @@ -3,12 +3,24 @@ import logging import os import threading +import types +from pathlib import Path from typing import Any import requests from bioblend.galaxy import GalaxyInstance from dotenv import find_dotenv, load_dotenv from fastmcp import FastMCP +from mcp.server.auth.middleware.auth_context import get_access_token +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + +from galaxy_mcp.auth import ( + GalaxyOAuthProvider, + configure_auth_provider, + get_active_session, +) # Set up logging logging.basicConfig(level=logging.INFO) @@ -46,41 +58,195 @@ def format_error(action: str, error: Exception, context: dict | None = None) -> load_dotenv(dotenv_path) print(f"Loaded environment variables from {dotenv_path}") -# Create an MCP server -mcp: FastMCP = FastMCP("Galaxy") - -# Galaxy client state +# Configure Galaxy target and client state +raw_galaxy_url = os.environ.get("GALAXY_URL") +normalized_galaxy_url = ( + raw_galaxy_url if not raw_galaxy_url or raw_galaxy_url.endswith("/") else f"{raw_galaxy_url}/" +) galaxy_state: dict[str, Any] = { - "url": os.environ.get("GALAXY_URL"), + "url": normalized_galaxy_url, "api_key": os.environ.get("GALAXY_API_KEY"), "gi": None, "connected": False, } +# Configure OAuth provider if requested +public_base_url = os.environ.get("GALAXY_MCP_PUBLIC_URL") +session_secret = os.environ.get("GALAXY_MCP_SESSION_SECRET") +client_registry_path_env = os.environ.get("GALAXY_MCP_CLIENT_REGISTRY") +default_registry_path = Path.home() / ".galaxy-mcp" / "clients.json" +client_registry_path = ( + Path(client_registry_path_env).expanduser() + if client_registry_path_env + else default_registry_path +) +auth_provider: GalaxyOAuthProvider | None = None +if public_base_url and normalized_galaxy_url: + try: + auth_provider = GalaxyOAuthProvider( + base_url=public_base_url, + galaxy_url=normalized_galaxy_url, + session_secret=session_secret, + client_registry_path=client_registry_path, + ) + configure_auth_provider(auth_provider) + logger.info("OAuth login enabled for Galaxy at %s", normalized_galaxy_url) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Failed to initialize OAuth provider: %s", exc, exc_info=True) +elif public_base_url and not normalized_galaxy_url: + logger.warning( + "GALAXY_MCP_PUBLIC_URL is set but GALAXY_URL is missing. " + "OAuth login remains disabled until GALAXY_URL is configured." + ) +else: + logger.info( + "OAuth login disabled. Configure GALAXY_MCP_PUBLIC_URL to enable browser-based login." + ) + +# Create an MCP server (inject auth provider when available) +if auth_provider: + mcp: FastMCP = FastMCP("Galaxy", auth=auth_provider) +else: + mcp = FastMCP("Galaxy") + +# Allow browser preflight CORS requests to bypass FastMCP auth + + +class _PreflightMiddleware(BaseHTTPMiddleware): + """Ensure CORS preflight requests succeed for browser-based clients.""" + + async def dispatch(self, request, call_next): + origin = request.headers.get("origin", "*") + allow_methods = request.headers.get("access-control-request-method", "POST,GET,OPTIONS") + allow_headers = request.headers.get( + "access-control-request-headers", "authorization,content-type" + ) + + cors_headers = { + "access-control-allow-origin": origin, + "access-control-allow-methods": allow_methods, + "access-control-allow-headers": allow_headers, + "access-control-max-age": "600", + } + + if request.method.upper() == "OPTIONS": + return Response(status_code=204, headers=cors_headers) + + response = await call_next(request) + for header, value in cors_headers.items(): + response.headers.setdefault(header, value) + return response + + +_original_http_app = FastMCP.http_app + + +class _OAuthPublicRoutes: + """Expose OAuth login and metadata routes without auth headers.""" + + def __init__(self, app, provider: GalaxyOAuthProvider, base_path: str | None): + self._app = app + self._provider = provider + self._login_paths = provider.get_login_paths(base_path) + self._metadata_paths = provider.get_resource_metadata_paths(base_path) + self.state = getattr(app, "state", None) + self.router = getattr(app, "router", None) + + def __getattr__(self, item): + return getattr(self._app, item) + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + await self._app(scope, receive, send) + return + + path = scope.get("path", "") + method = scope.get("method", "").upper() + if path in self._metadata_paths: + if method not in {"GET", "HEAD"}: + await self._app(scope, receive, send) + return + request = Request(scope, receive=receive) + response = await self._provider.handle_resource_metadata(request) + await response(scope, receive, send) + return + + if path in self._login_paths and method in {"GET", "POST"}: + request = Request(scope, receive=receive) + response = await self._provider.handle_login(request) + await response(scope, receive, send) + return + + await self._app(scope, receive, send) + + +def _http_app_with_preflight(self, *args, **kwargs): + app = _original_http_app(self, *args, **kwargs) + app.add_middleware(_PreflightMiddleware) + if auth_provider: + base_path = kwargs.get("path") + app = _OAuthPublicRoutes(app, auth_provider, base_path) + return app + + +mcp.http_app = types.MethodType(_http_app_with_preflight, mcp) + # Initialize Galaxy client if environment variables are set if galaxy_state["url"] and galaxy_state["api_key"]: try: - galaxy_url = ( - galaxy_state["url"] if galaxy_state["url"].endswith("/") else f"{galaxy_state['url']}/" - ) - galaxy_state["url"] = galaxy_url - galaxy_state["gi"] = GalaxyInstance(url=galaxy_url, key=galaxy_state["api_key"]) + galaxy_state["gi"] = GalaxyInstance(url=galaxy_state["url"], key=galaxy_state["api_key"]) galaxy_state["connected"] = True - logger.info(f"Galaxy client initialized from environment variables (URL: {galaxy_url})") + logger.info( + "Galaxy client initialized from environment variables (URL: %s)", + galaxy_state["url"], + ) except Exception as e: logger.warning(f"Failed to initialize Galaxy client from environment variables: {e}") logger.warning("You'll need to use connect() to establish a connection.") -def ensure_connected(): - """Helper function to ensure Galaxy connection is established""" - if not galaxy_state["connected"] or not galaxy_state["gi"]: +def _get_request_connection_state() -> dict[str, Any]: + """ + Determine the effective Galaxy connection, preferring OAuth credentials when available. + """ + if auth_provider: + credentials, api_key = get_active_session(get_access_token) + if credentials and api_key: + try: + gi = GalaxyInstance(url=credentials.galaxy_url, key=api_key) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Failed to create Galaxy client for OAuth session: %s", exc) + else: + return { + "url": credentials.galaxy_url, + "api_key": api_key, + "gi": gi, + "connected": True, + "source": "oauth", + "session": credentials, + } + + return { + "url": galaxy_state.get("url") or normalized_galaxy_url, + "api_key": galaxy_state.get("api_key"), + "gi": galaxy_state.get("gi"), + "connected": galaxy_state.get("connected", False) and bool(galaxy_state.get("gi")), + "source": "api_key" if galaxy_state.get("connected") else None, + "session": None, + } + + +def ensure_connected() -> dict[str, Any]: + """Helper function to ensure Galaxy connection is established.""" + state = _get_request_connection_state() + if not state["connected"] or not state["gi"]: raise ValueError( - "Not connected to Galaxy. " - "Please run connect() first with your Galaxy URL and API key. " - "Example: connect(url='https://your-galaxy.org', api_key='your-key')" + "Not connected to Galaxy. Authenticate via OAuth or run connect() with your " + "Galaxy URL and API key. Example: connect(url='https://your-galaxy.org', " + "api_key='your-key')" ) + return state @mcp.tool() @@ -96,6 +262,18 @@ def connect(url: str | None = None, api_key: str | None = None) -> dict[str, Any Connection status and user information """ try: + # Reuse current OAuth session when available + state = _get_request_connection_state() + if state["connected"] and state.get("source") == "oauth" and state["gi"]: + gi: GalaxyInstance = state["gi"] + user_info = gi.users.get_current_user() + return { + "connected": True, + "user": user_info, + "url": state["url"], + "auth": "oauth", + } + # Use provided parameters or fall back to environment variables use_url = url or os.environ.get("GALAXY_URL") use_api_key = api_key or os.environ.get("GALAXY_API_KEY") @@ -146,6 +324,7 @@ def connect(url: str | None = None, api_key: str | None = None) -> dict[str, Any galaxy_state["gi"] = None galaxy_state["connected"] = False + galaxy_url = locals().get("galaxy_url") or use_url or normalized_galaxy_url or "unknown" error_msg = f"Failed to connect to Galaxy at {galaxy_url}: {str(e)}" if "401" in str(e) or "authentication" in str(e).lower(): error_msg += " Check that your API key is valid and has the necessary permissions." @@ -170,11 +349,12 @@ def search_tools(query: str) -> dict[str, Any]: Returns: List of tools matching the query """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] try: # The get_tools method is used with name filter parameter - tools = galaxy_state["gi"].tools.get_tools(name=query) + tools = gi.tools.get_tools(name=query) return {"tools": tools} except Exception as e: raise ValueError( @@ -195,11 +375,12 @@ def get_tool_details(tool_id: str, io_details: bool = False) -> dict[str, Any]: Returns: Tool details """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] try: # Get detailed information about the tool - tool_info = galaxy_state["gi"].tools.show_tool(tool_id, io_details=io_details) + tool_info = gi.tools.show_tool(tool_id, io_details=io_details) return tool_info except Exception as e: raise ValueError( @@ -220,11 +401,12 @@ def get_tool_citations(tool_id: str) -> dict[str, Any]: Returns: Tool citation information """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] try: # Get the tool information which includes citations - tool_info = galaxy_state["gi"].tools.show_tool(tool_id) + tool_info = gi.tools.show_tool(tool_id) # Extract citation information citations = tool_info.get("citations", []) @@ -253,11 +435,12 @@ def run_tool(history_id: str, tool_id: str, inputs: dict[str, Any]) -> dict[str, Returns: Dictionary containing tool execution information including job IDs and output dataset IDs """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] try: # Run the tool with provided inputs - result = galaxy_state["gi"].tools.run_tool(history_id, tool_id, inputs) + result = gi.tools.run_tool(history_id, tool_id, inputs) return result except Exception as e: error_msg = f"Failed to run tool '{tool_id}' in history '{history_id}': {str(e)}" @@ -279,11 +462,12 @@ def get_tool_panel() -> dict[str, Any]: Returns: Tool panel hierarchy """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] try: # Get the tool panel structure - tool_panel = galaxy_state["gi"].tools.get_tool_panel() + tool_panel = gi.tools.get_tool_panel() return {"tool_panel": tool_panel} except Exception as e: raise ValueError(f"Failed to get tool panel: {str(e)}") from e @@ -300,8 +484,9 @@ def create_history(history_name: str) -> dict[str, Any]: Returns: Dictionary containing the created history details including the new history ID hash """ - ensure_connected() - return galaxy_state["gi"].histories.create_history(history_name) + state = ensure_connected() + gi: GalaxyInstance = state["gi"] + return gi.histories.create_history(history_name) @mcp.tool() @@ -318,14 +503,15 @@ def filter_tools_by_dataset(dataset_type: list[str]) -> dict[str, Any]: dict: A dictionary containing the list of recommended tools and the total count. """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] lock = threading.Lock() dataset_keywords = [dt.lower() for dt in dataset_type] try: - tool_panel = galaxy_state["gi"].tools.get_tool_panel() + tool_panel = gi.tools.get_tool_panel() def flatten_tools(panel): tools = [] @@ -364,7 +550,7 @@ def check_tool(tool): if tool_id.endswith("_label"): return None try: - tool_details = galaxy_state["gi"].tools.show_tool(tool_id, io_details=True) + tool_details = gi.tools.show_tool(tool_id, io_details=True) tool_inputs = tool_details.get("inputs", [{}]) for input_spec in tool_inputs: if not isinstance(input_spec, dict): @@ -418,18 +604,20 @@ def get_server_info() -> dict[str, Any]: Returns: Server information including version, URL, and other configuration details """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] + url = state["url"] or normalized_galaxy_url try: # Get server configuration info - config_info = galaxy_state["gi"].config.get_config() + config_info = gi.config.get_config() # Get server version info - version_info = galaxy_state["gi"].config.get_version() + version_info = gi.config.get_version() # Build comprehensive server info response server_info = { - "url": galaxy_state["url"], + "url": url, "version": version_info, "config": { "brand": config_info.get("brand", "Galaxy"), @@ -466,10 +654,11 @@ def get_user() -> dict[str, Any]: Returns: Current user details """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] try: - user_info = galaxy_state["gi"].users.get_current_user() + user_info = gi.users.get_current_user() return user_info except Exception as e: raise ValueError(f"Failed to get user: {str(e)}") from e @@ -490,18 +679,17 @@ def get_histories( Returns: Dictionary containing list of histories and pagination metadata """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] try: # Get histories with pagination and optional filtering - histories = galaxy_state["gi"].histories.get_histories( - limit=limit, offset=offset, name=name - ) + histories = gi.histories.get_histories(limit=limit, offset=offset, name=name) # If pagination is used, get total count for metadata if limit is not None: # Get total count without pagination - all_histories = galaxy_state["gi"].histories.get_histories(name=name) + all_histories = gi.histories.get_histories(name=name) total_items = len(all_histories) if all_histories else 0 # Calculate pagination metadata @@ -553,10 +741,11 @@ def list_history_ids() -> list[dict[str, str]]: Returns: List of dictionaries containing 'id' and 'name' fields """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] try: - histories = galaxy_state["gi"].histories.get_histories() + histories = gi.histories.get_histories() if not histories: return [] # Extract just the id and name for convenience @@ -587,17 +776,18 @@ def get_history_details(history_id: str) -> dict[str, Any]: To get actual datasets: Use get_history_contents(history_id, limit=N, order="create_time-dsc") """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] try: logger.info(f"Getting details for history ID: {history_id}") # Get history details - history_info = galaxy_state["gi"].histories.show_history(history_id, contents=False) + history_info = gi.histories.show_history(history_id, contents=False) logger.info(f"Successfully retrieved history info: {history_info.get('name', 'Unknown')}") # Get total count by calling without limit - all_contents = galaxy_state["gi"].histories.show_history(history_id, contents=True) + all_contents = gi.histories.show_history(history_id, contents=True) total_items = len(all_contents) if all_contents else 0 return { @@ -649,7 +839,8 @@ def get_history_contents( Returns: Dictionary containing paginated dataset list, pagination metadata, and history reference """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] try: logger.info( @@ -658,7 +849,7 @@ def get_history_contents( ) # Use datasets API for better ordering support - contents = galaxy_state["gi"].datasets.get_datasets( + contents = gi.datasets.get_datasets( limit=limit, offset=offset, history_id=history_id, @@ -674,7 +865,7 @@ def get_history_contents( contents = [item for item in contents if item.get("visible", True)] # Get total count for pagination metadata - all_contents = galaxy_state["gi"].datasets.get_datasets( + all_contents = gi.datasets.get_datasets( history_id=history_id, order=order, ) @@ -740,12 +931,17 @@ def get_job_details(dataset_id: str, history_id: str | None = None) -> dict[str, Returns: Dictionary containing job metadata, tool information, dataset ID, and job ID """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] + base_url = state["url"] or normalized_galaxy_url or "" + api_key = state["api_key"] + if not base_url or not api_key: + raise ValueError("Galaxy connection is missing URL or API key information.") try: # Get dataset provenance to find the creating job try: - provenance = galaxy_state["gi"].histories.show_dataset_provenance( + provenance = gi.histories.show_dataset_provenance( history_id=history_id, dataset_id=dataset_id ) @@ -760,7 +956,7 @@ def get_job_details(dataset_id: str, history_id: str | None = None) -> dict[str, except Exception as provenance_error: # If provenance fails, try getting dataset details which might contain job info try: - dataset_details = galaxy_state["gi"].datasets.show_dataset(dataset_id) + dataset_details = gi.datasets.show_dataset(dataset_id) job_id = dataset_details.get("creating_job") if not job_id: raise ValueError( @@ -775,9 +971,9 @@ def get_job_details(dataset_id: str, history_id: str | None = None) -> dict[str, # Get job details using the Galaxy API directly # (Bioblend doesn't have a direct method for this) - url = f"{galaxy_state['url']}api/jobs/{job_id}" - headers = {"x-api-key": galaxy_state["api_key"]} - response = requests.get(url, headers=headers) + url = f"{base_url}api/jobs/{job_id}" + headers = {"x-api-key": api_key} + response = requests.get(url, headers=headers, timeout=30) response.raise_for_status() job_info = response.json() @@ -809,11 +1005,12 @@ def get_dataset_details( Dictionary containing dataset metadata (name, size, state, extension) and optional content preview with line count and truncation information """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] try: # Get dataset details using bioblend - dataset_info = galaxy_state["gi"].datasets.show_dataset(dataset_id) + dataset_info = gi.datasets.show_dataset(dataset_id) result = {"dataset": dataset_info, "dataset_id": dataset_id} @@ -821,7 +1018,7 @@ def get_dataset_details( if include_preview and dataset_info.get("state") == "ok": try: # Get dataset content for preview - content = galaxy_state["gi"].datasets.download_dataset( + content = gi.datasets.download_dataset( dataset_id, use_default_filename=False, require_ok_state=False ) @@ -899,11 +1096,12 @@ def download_dataset( environments), omit the file_path parameter to download content to memory. Only specify file_path if you can actually write files to the local filesystem. """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] try: # Get dataset info first to check state and get metadata - dataset_info = galaxy_state["gi"].datasets.show_dataset(dataset_id) + dataset_info = gi.datasets.show_dataset(dataset_id) # Check dataset state if required if require_ok_state and dataset_info.get("state") != "ok": @@ -915,7 +1113,7 @@ def download_dataset( # Download the dataset if file_path: # Download to specific path - result_path = galaxy_state["gi"].datasets.download_dataset( + result_path = gi.datasets.download_dataset( dataset_id, file_path=file_path, use_default_filename=False, @@ -930,7 +1128,7 @@ def download_dataset( else: # Download content to memory (don't save to filesystem) - result_path = galaxy_state["gi"].datasets.download_dataset( + result_path = gi.datasets.download_dataset( dataset_id, use_default_filename=False, # Get content in memory require_ok_state=require_ok_state, @@ -987,7 +1185,8 @@ def upload_file(path: str, history_id: str | None = None) -> dict[str, Any]: Returns: Dictionary containing upload status and information about the created dataset(s) """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] try: if not os.path.exists(path): @@ -997,7 +1196,7 @@ def upload_file(path: str, history_id: str | None = None) -> dict[str, Any]: "Check that the file exists and you have read permissions." ) - result = galaxy_state["gi"].tools.upload_file(path, history_id=history_id) + result = gi.tools.upload_file(path, history_id=history_id) return result except Exception as e: raise ValueError(f"Failed to upload file: {str(e)}") from e @@ -1031,16 +1230,17 @@ def get_invocations( Returns: Dictionary containing workflow invocation information, execution status, and step details """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] try: # If invocation_id is provided, get details of a specific invocation if invocation_id: - invocation = galaxy_state["gi"].invocations.show_invocation(invocation_id) + invocation = gi.invocations.show_invocation(invocation_id) return {"invocation": invocation} # Otherwise get a list of invocations with optional filters - invocations = galaxy_state["gi"].invocations.get_invocations( + invocations = gi.invocations.get_invocations( workflow_id=workflow_id, history_id=history_id, limit=limit, @@ -1124,7 +1324,8 @@ def import_workflow_from_iwc(trs_id: str) -> dict[str, Any]: Returns: Imported workflow information """ - ensure_connected() + state = ensure_connected() + gi: GalaxyInstance = state["gi"] try: # Get the workflow manifest @@ -1154,11 +1355,54 @@ def import_workflow_from_iwc(trs_id: str) -> dict[str, Any]: ) # Import the workflow into Galaxy - imported_workflow = galaxy_state["gi"].workflows.import_workflow_dict(workflow_definition) + imported_workflow = gi.workflows.import_workflow_dict(workflow_definition) return {"imported_workflow": imported_workflow} except Exception as e: raise ValueError(f"Failed to import workflow from IWC: {str(e)}") from e +def run_http_server( + *, + host: str | None = None, + port: int | None = None, + transport: str | None = None, + path: str | None = None, +) -> None: + """Run the MCP server over HTTP-based transport.""" + resolved_host = host or os.environ.get("GALAXY_MCP_HOST", "0.0.0.0") + resolved_port = int(port or os.environ.get("GALAXY_MCP_PORT", "8000")) + resolved_transport = ( + transport or os.environ.get("GALAXY_MCP_TRANSPORT") or "streamable-http" + ).lower() + if resolved_transport not in {"streamable-http", "sse"}: + raise ValueError( + f"Unsupported transport '{resolved_transport}'. Choose 'streamable-http' or 'sse'." + ) + + resolved_path = path or os.environ.get("GALAXY_MCP_HTTP_PATH") + if resolved_path is None and resolved_transport == "streamable-http": + resolved_path = "/" + if resolved_path is not None and not resolved_path.startswith("/"): + resolved_path = f"/{resolved_path}" + + logger.info( + "Starting Galaxy MCP server over %s at %s:%s%s", + resolved_transport, + resolved_host, + resolved_port, + resolved_path or "", + ) + mcp.run( + transport=resolved_transport, + host=resolved_host, + port=resolved_port, + path=resolved_path, + ) + + if __name__ == "__main__": - mcp.run() + selected_transport = os.environ.get("GALAXY_MCP_TRANSPORT", "stdio").lower() + if selected_transport in {"streamable-http", "sse"}: + run_http_server(transport=selected_transport) + else: + mcp.run() diff --git a/mcp-server-galaxy-py/tests/README.md b/mcp-server-galaxy-py/tests/README.md index 2121faa..a1faffa 100644 --- a/mcp-server-galaxy-py/tests/README.md +++ b/mcp-server-galaxy-py/tests/README.md @@ -11,6 +11,7 @@ This directory contains integration tests for the Galaxy MCP server. - `test_tool_operations.py` - Tests for tool search and execution - `test_workflow_operations.py` - Tests for workflow operations - `test_integration.py` - End-to-end integration tests +- `test_oauth.py` - Tests for OAuth session handling and public endpoints ## Running Tests diff --git a/mcp-server-galaxy-py/tests/TEST_STRATEGY.md b/mcp-server-galaxy-py/tests/TEST_STRATEGY.md index d53deaf..fb77aa7 100644 --- a/mcp-server-galaxy-py/tests/TEST_STRATEGY.md +++ b/mcp-server-galaxy-py/tests/TEST_STRATEGY.md @@ -31,6 +31,7 @@ tests/ ├── test_dataset_operations.py # Dataset upload/download ├── test_tool_operations.py # Tool search and execution ├── test_workflow_operations.py # Workflow import and invocation +├── test_oauth.py # OAuth flow and HTTP public routes └── test_integration.py # End-to-end scenarios ``` diff --git a/mcp-server-galaxy-py/tests/test_connection.py b/mcp-server-galaxy-py/tests/test_connection.py index 978e2b7..d4238f9 100644 --- a/mcp-server-galaxy-py/tests/test_connection.py +++ b/mcp-server-galaxy-py/tests/test_connection.py @@ -5,8 +5,9 @@ from unittest.mock import patch import pytest +from galaxy_mcp.auth import GalaxyCredentials -from .test_helpers import ensure_connected, galaxy_state, get_server_info_fn +from .test_helpers import connect_fn, ensure_connected, galaxy_state, get_server_info_fn @pytest.mark.usefixtures("_test_env") @@ -59,6 +60,39 @@ def test_connection_with_missing_credentials(self): # Without credentials, should not connect assert not galaxy_state.get("connected", False) + def test_connect_returns_oauth_session(self, mock_galaxy_instance): + """Ensure connect() reports OAuth session details when available.""" + credentials = GalaxyCredentials( + galaxy_url="https://oauth.galaxy/", + api_key="oauth-api-key", + username="oauth-user", + user_email="oauth@example.com", + expires_at=1_700_000_000, + scopes=["galaxy:full"], + client_id="client-123", + ) + + with patch("galaxy_mcp.server.auth_provider", object()): + with patch( + "galaxy_mcp.server.get_active_session", + return_value=(credentials, credentials.api_key), + ): + with patch( + "galaxy_mcp.server.GalaxyInstance", return_value=mock_galaxy_instance + ) as mock_constructor: + result = connect_fn() + + assert result["connected"] is True + assert result["auth"] == "oauth" + assert result["url"] == credentials.galaxy_url + assert ( + result["user"]["username"] + == mock_galaxy_instance.users.get_current_user.return_value["username"] + ) + mock_constructor.assert_called_once_with( + url=credentials.galaxy_url, key=credentials.api_key + ) + def test_get_server_info_success(self, mock_galaxy_instance): """Test successful server info retrieval""" # Mock server config and version responses diff --git a/mcp-server-galaxy-py/tests/test_oauth.py b/mcp-server-galaxy-py/tests/test_oauth.py new file mode 100644 index 0000000..06044aa --- /dev/null +++ b/mcp-server-galaxy-py/tests/test_oauth.py @@ -0,0 +1,100 @@ +"""Tests for OAuth-aware functionality.""" + +from unittest.mock import patch + +import pytest +from galaxy_mcp.auth import GalaxyCredentials +from galaxy_mcp.server import _OAuthPublicRoutes, ensure_connected +from starlette.applications import Starlette +from starlette.responses import JSONResponse, PlainTextResponse +from starlette.testclient import TestClient + + +@pytest.mark.usefixtures("_test_env") +def test_ensure_connected_prefers_oauth_session(mock_galaxy_instance): + """ensure_connected should build a Galaxy client from the active OAuth session.""" + credentials = GalaxyCredentials( + galaxy_url="https://oauth.galaxy/", + api_key="oauth-api-key", + username="oauth-user", + user_email="oauth@example.com", + expires_at=1_700_000_000, + scopes=["galaxy:full"], + client_id="client-123", + ) + + with patch("galaxy_mcp.server.auth_provider", object()): + with patch( + "galaxy_mcp.server.get_active_session", + return_value=(credentials, credentials.api_key), + ): + with patch( + "galaxy_mcp.server.GalaxyInstance", return_value=mock_galaxy_instance + ) as mock_constructor: + state = ensure_connected() + + assert state["source"] == "oauth" + assert state["connected"] is True + assert state["gi"] is mock_galaxy_instance + assert state["url"] == credentials.galaxy_url + mock_constructor.assert_called_once_with(url=credentials.galaxy_url, key=credentials.api_key) + + +class _DummyProvider: + def __init__(self): + self.login_calls = 0 + self.metadata_calls = 0 + + def get_login_paths(self, base_path=None): + return {"/galaxy-auth/login"} + + def get_resource_metadata_paths(self, base_path=None): + return {"/.well-known/oauth-protected-resource"} + + async def handle_login(self, request): + self.login_calls += 1 + return PlainTextResponse("login-ok") + + async def handle_resource_metadata(self, request): + self.metadata_calls += 1 + return JSONResponse({"resource": "https://oauth.galaxy/"}) + + +def test_oauth_public_routes_expose_login_and_metadata(): + """_OAuthPublicRoutes should short-circuit auth for public endpoints.""" + + app = Starlette() + + async def ping(_request): # pragma: no cover - trivial handler + return PlainTextResponse("pong") + + app.add_route("/ping", ping, methods=["GET"]) + + provider = _DummyProvider() + wrapped = _OAuthPublicRoutes(app, provider, base_path=None) + + # The wrapper should proxy common Starlette attributes + assert wrapped.state is app.state + assert wrapped.router is app.router + + client = TestClient(wrapped) + + # Public endpoints are handled by the provider + login_response = client.get("/galaxy-auth/login") + assert login_response.status_code == 200 + assert login_response.text == "login-ok" + assert provider.login_calls == 1 + + metadata_response = client.get("/.well-known/oauth-protected-resource") + assert metadata_response.status_code == 200 + assert metadata_response.json() == {"resource": "https://oauth.galaxy/"} + assert provider.metadata_calls == 1 + + # Non-public endpoints fall back to the underlying application + ping_response = client.get("/ping") + assert ping_response.status_code == 200 + assert ping_response.text == "pong" + + # Unsupported methods should flow through to Starlette and return 404 + fallback_response = client.post("/.well-known/oauth-protected-resource") + assert fallback_response.status_code == 404 diff --git a/mcp-server-galaxy-py/uv.lock b/mcp-server-galaxy-py/uv.lock index 4fc49cf..a88b24c 100644 --- a/mcp-server-galaxy-py/uv.lock +++ b/mcp-server-galaxy-py/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" [[package]] @@ -654,10 +654,12 @@ wheels = [ [[package]] name = "galaxy-mcp" -version = "1.0.0" +version = "1.1.0" source = { editable = "." } dependencies = [ + { name = "anyio" }, { name = "bioblend" }, + { name = "cryptography" }, { name = "fastmcp" }, { name = "python-dotenv" }, { name = "requests" }, @@ -683,8 +685,10 @@ dev = [ [package.metadata] requires-dist = [ + { name = "anyio", specifier = ">=4.0.0" }, { name = "bioblend", specifier = ">=1.5.0" }, { name = "build", marker = "extra == 'dev'", specifier = ">=1.2.1" }, + { name = "cryptography", specifier = ">=41.0.0" }, { name = "fastmcp", specifier = ">=2.3.0" }, { name = "httpx", marker = "extra == 'dev'", specifier = ">=0.27.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" },