diff --git a/README.md b/README.md index 38d5becaf..311918abb 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,26 @@ ToolHive is available as a GUI desktop app, CLI, and Kubernetes Operator. +## Kubernetes Operator + +ToolHive includes a Kubernetes Operator for enterprise and production deployments: + +### Features + +- **MCPServer CRD**: Deploy and manage MCP servers as Kubernetes resources +- **MCPRegistry CRD** *(Experimental)*: Centralized registry management with automated sync +- **Secure isolation**: Container-based server execution with permission profiles +- **Protocol proxying**: Stdio servers exposed via HTTP/SSE networking protocols +- **Service discovery**: Automatic service creation and DNS integration + +### Documentation + +- [Operator Guide](cmd/thv-operator/README.md) - Complete operator documentation +- [MCPRegistry Reference](cmd/thv-operator/REGISTRY.md) - Registry management (experimental) +- [CRD API Reference](docs/operator/crd-api.md) - Auto-generated API documentation +- [Deployment Guide](docs/kind/deploying-toolhive-operator.md) - Step-by-step installation +- [Examples](examples/operator/) - Sample configurations + ## Quick links - 📚 [Documentation](https://docs.stacklok.com/toolhive/) diff --git a/cmd/thv-operator/CLAUDE.md b/cmd/thv-operator/CLAUDE.md index 40d789c84..550c1777d 100644 --- a/cmd/thv-operator/CLAUDE.md +++ b/cmd/thv-operator/CLAUDE.md @@ -13,6 +13,62 @@ After modifying the CRDs, the following needs to be run: When committing a change that changes CRDs, it is important to bump the chart version as described in the [CLAUDE.md](../../deploy/charts/operator-crds/CLAUDE.md#bumping-crd-chart) doc for the CRD Helm Chart. +## MCPRegistry CRD (Experimental) + +The MCPRegistry CRD enables centralized management of MCP server registries. Requires `operator.features.experimental=true`. + +### Key Components + +- **CRD**: `api/v1alpha1/mcpregistry_types.go` +- **Controller**: `controllers/mcpregistry_controller.go` +- **Status**: `pkg/mcpregistrystatus/` +- **Sync**: `pkg/sync/` +- **Sources**: `pkg/sources/` +- **API**: `pkg/registryapi/` + +### Development Patterns + +#### Status Collector Pattern + +Always use StatusCollector for batched updates: + +```go +// ✅ Good: Collect all changes, apply once +statusCollector := mcpregistrystatus.NewCollector(mcpRegistry) +statusCollector.SetPhase(mcpv1alpha1.MCPRegistryPhaseReady) +statusCollector.Apply(ctx, r.Client) + +// ❌ Bad: Multiple individual updates cause conflicts +r.Status().Update(ctx, mcpRegistry) +``` + +#### Error Handling + +Always set status before returning errors: + +```go +if err := validateSource(); err != nil { + statusCollector.SetSyncStatus(mcpv1alpha1.SyncPhaseFailed, err.Error(), ...) + return ctrl.Result{RequeueAfter: time.Minute * 5}, err +} +``` + +#### Source Handler Interface + +```go +type SourceHandler interface { + FetchRegistryData(ctx context.Context, source MCPRegistrySource) (*RegistryData, error) + ValidateSource(ctx context.Context, source MCPRegistrySource) error + CalculateHash(ctx context.Context, source MCPRegistrySource) (string, error) +} +``` + +### Testing Patterns + +- **Unit Tests**: Use mocks for external dependencies +- **Integration Tests**: Use envtest framework +- **E2E Tests**: Missing for MCPRegistry (use Chainsaw) + ## OpenTelemetry (OTEL) Stack for Testing When you have been asked to stand up an OTEL stack to test ToolHives integration inside of Kubernetes, you will need to perform the following tasks inside of the cluster that you have been instructed to use. diff --git a/cmd/thv-operator/DESIGN.md b/cmd/thv-operator/DESIGN.md index e58312e87..8f2ac745d 100644 --- a/cmd/thv-operator/DESIGN.md +++ b/cmd/thv-operator/DESIGN.md @@ -1,44 +1,111 @@ # Design & Decisions -This document aims to help fill in gaps of any decision that are made around the design of the ToolHive Operator. +This document captures architectural decisions and design patterns for the ToolHive Operator. -## CRD Attribute vs `PodTemplateSpec` +## Operator Design Principles + +### CRD Attribute vs `PodTemplateSpec` When building operators, the decision of when to use a `podTemplateSpec` and when to use a CRD attribute is always disputed. For the ToolHive Operator we have a defined rule of thumb. -### Use Dedicated CRD Attributes For: +#### Use Dedicated CRD Attributes For: - **Business logic** that affects your operator's behavior -- **Validation requirements** (ranges, formats, constraints) +- **Validation requirements** (ranges, formats, constraints) - **Cross-resource coordination** (affects Services, ConfigMaps, etc.) - **Operator decision making** (triggers different reconciliation paths) -```yaml -spec: - version: "13.4" # Affects operator logic - replicas: 3 # Affects scaling behavior - backupSchedule: "0 2 * * *" # Needs validation -``` - -### Use PodTemplateSpec For: +#### Use PodTemplateSpec For: - **Infrastructure concerns** (node selection, resources, affinity) -- **Sidecar containers** +- **Sidecar containers** - **Standard Kubernetes pod configuration** - **Things a cluster admin would typically configure** -```yaml -spec: - podTemplate: - spec: - nodeSelector: - disktype: ssd - containers: - - name: sidecar - image: monitoring:latest -``` - -## Quick Decision Test: +#### Quick Decision Test: 1. **"Does this affect my operator's reconciliation logic?"** -> Dedicated attribute -2. **"Is this standard Kubernetes pod configuration?"** -> PodTemplateSpec +2. **"Is this standard Kubernetes pod configuration?"** -> PodTemplateSpec 3. **"Do I need to validate this beyond basic Kubernetes validation?"** -> Dedicated attribute -This gives you a clean API for core functionality while maintaining flexibility for infrastructure concerns. \ No newline at end of file +## MCPRegistry Architecture Decisions + +### Status Management Design + +**Decision**: Use batched status updates via StatusCollector pattern instead of individual field updates. + +**Rationale**: +- Prevents race conditions between multiple status updates +- Reduces API server load with fewer update calls +- Ensures consistent status across reconciliation cycles +- Handles resource version conflicts gracefully + +**Implementation**: StatusCollector interface collects all changes and applies them atomically. + +### Sync Operation Design + +**Decision**: Separate sync decision logic from sync execution with clear interfaces. + +**Rationale**: +- Testability: Mock sync decisions independently from execution +- Flexibility: Different sync strategies without changing core logic +- Maintainability: Clear separation of concerns + +**Key Patterns**: +- Idempotent operations for safe retry +- Manual vs automatic sync distinction +- Data preservation on failures + +### Storage Architecture + +**Decision**: Abstract storage via StorageManager interface with ConfigMap as default implementation. + +**Rationale**: +- Future flexibility: Easy addition of new storage backends (OCI, databases) +- Testability: Mock storage for unit tests +- Consistency: Single interface for all storage operations + +**Current Implementation**: ConfigMap-based with owner references for automatic cleanup. + +### Registry API Service Pattern + +**Decision**: Deploy individual API service per MCPRegistry rather than shared service. + +**Rationale**: +- **Isolation**: Each registry has independent lifecycle and scaling +- **Security**: Per-registry access control possible +- **Reliability**: Failure of one registry doesn't affect others +- **Lifecycle Management**: Automatic cleanup via owner references + +**Trade-offs**: More resources consumed but better isolation and security. + +### Error Handling Strategy + +**Decision**: Structured error types with progressive retry backoff. + +**Rationale**: +- Different error types need different handling strategies +- Progressive backoff prevents thundering herd problems +- Structured errors enable better observability + +**Implementation**: 5m initial retry, exponential backoff with cap, manual sync bypass. + +### Performance Design Decisions + +#### Resource Optimization +- **Status Updates**: Batched to reduce API calls (implemented) +- **Source Fetching**: Planned caching to avoid repeated downloads +- **API Deployment**: Lazy creation only when needed (implemented) + +#### Memory Management +- **Git Operations**: Shallow clones to minimize disk usage (implemented) +- **Large Registries**: Stream processing planned for future +- **Status Objects**: Efficient field-level updates (implemented) + +### Security Architecture + +#### Permission Model +Minimal required permissions following principle of least privilege: +- ConfigMaps: For storage management +- Services/Deployments: For API service management +- MCPRegistry: For status updates + +#### Network Security +Optional network policies for registry API access control in security-sensitive environments. diff --git a/cmd/thv-operator/README.md b/cmd/thv-operator/README.md index a49c5a3fe..deb2d181c 100644 --- a/cmd/thv-operator/README.md +++ b/cmd/thv-operator/README.md @@ -1,18 +1,34 @@ # ToolHive Kubernetes Operator -The ToolHive Kubernetes Operator manages MCP (Model Context Protocol) servers in Kubernetes clusters. It allows you to define MCP servers as Kubernetes resources and automates their deployment and management. +The ToolHive Kubernetes Operator manages MCP (Model Context Protocol) servers and registries in Kubernetes clusters. It allows you to define MCP servers and registries as Kubernetes resources and automates their deployment and management. This operator is built using [Kubebuilder](https://book.kubebuilder.io/), a framework for building Kubernetes APIs using Custom Resource Definitions (CRDs). ## Overview -The operator introduces a new Custom Resource Definition (CRD) called `MCPServer` that represents an MCP server in Kubernetes. When you create an `MCPServer` resource, the operator automatically: +The operator introduces two main Custom Resource Definitions (CRDs): + +### MCPServer +Represents an MCP server in Kubernetes. When you create an `MCPServer` resource, the operator automatically: 1. Creates a Deployment to run the MCP server 2. Sets up a Service to expose the MCP server 3. Configures the appropriate permissions and settings 4. Manages the lifecycle of the MCP server +### MCPRegistry (Experimental) + +> ⚠️ **Experimental Feature**: MCPRegistry requires `ENABLE_EXPERIMENTAL_FEATURES=true` + +Represents an MCP server registry in Kubernetes. When you create an `MCPRegistry` resource, the operator automatically: + +1. Synchronizes registry data from various sources (ConfigMap, Git) +2. Deploys a Registry API service for server discovery +3. Provides content filtering and image validation +4. Manages automatic and manual synchronization policies + +For detailed MCPRegistry documentation, see [REGISTRY.md](REGISTRY.md). + ```mermaid --- config: @@ -107,7 +123,11 @@ helm upgrade -i toolhive-operator-crds oci://ghcr.io/stacklok/toolhive/toolhive- 2. Install the operator: ```bash +# Standard installation helm upgrade -i oci://ghcr.io/stacklok/toolhive/toolhive-operator --version= -n toolhive-system --create-namespace + +# OR with experimental features (for MCPRegistry support) +helm upgrade -i oci://ghcr.io/stacklok/toolhive/toolhive-operator --version= -n toolhive-system --create-namespace --set operator.features.experimental=true ``` ## Usage @@ -236,9 +256,49 @@ permissionProfile: The ConfigMap should contain a JSON permission profile. +### Creating an MCP Registry (Experimental) + +> ⚠️ **Requires**: `operator.features.experimental=true` + +First, create a ConfigMap containing ToolHive registry data. The ConfigMap must be user-defined and is not managed by the operator: + +```bash +# Create ConfigMap from existing registry data +kubectl create configmap my-registry-data --from-file registry.json=pkg/registry/data/registry.json -n toolhive-system + +# Or create from your own registry file +kubectl create configmap my-registry-data --from-file registry.json=/path/to/your/registry.json -n toolhive-system +``` + +Then create the MCPRegistry resource that references the ConfigMap: + +```yaml +apiVersion: toolhive.stacklok.dev/v1alpha1 +kind: MCPRegistry +metadata: + name: my-registry + namespace: toolhive-system +spec: + displayName: "My MCP Registry" + source: + type: configmap + configmap: + name: my-registry-data # References the user-created ConfigMap + key: registry.json # Key in ConfigMap (default: "registry.json") + syncPolicy: + interval: "1h" + filter: + tags: + include: ["production"] + exclude: ["experimental"] +``` + +For complete MCPRegistry examples and documentation, see [REGISTRY.md](REGISTRY.md). + ## Examples -See the `examples/operator/mcp-servers/` directory for example MCPServer resources. +- **MCPServer examples**: `examples/operator/mcp-servers/` directory +- **MCPRegistry examples**: `examples/operator/mcp-registries/` directory ## Development diff --git a/cmd/thv-operator/REGISTRY.md b/cmd/thv-operator/REGISTRY.md new file mode 100644 index 000000000..cc0d6186b --- /dev/null +++ b/cmd/thv-operator/REGISTRY.md @@ -0,0 +1,490 @@ +# MCPRegistry Reference + +> ⚠️ **Experimental Feature**: MCPRegistry requires enabling experimental features with `--set operator.features.experimental=true` during Helm installation. + +## Overview + +MCPRegistry is a Kubernetes Custom Resource that manages MCP (Model Context Protocol) server registries. It provides centralized server discovery, automated synchronization, content filtering, and image validation for MCP servers in your cluster. + +## Quick Start + +Create a basic registry from a ConfigMap: + +```yaml +apiVersion: v1 +kind: ConfigMap +metadata: + name: my-registry-data + namespace: toolhive-system +data: + registry.json: | + { + "$schema": "https://raw.githubusercontent.com/stacklok/toolhive/main/pkg/registry/data/schema.json", + "version": "1.0.0", + "last_updated": "2025-01-14T00:00:00Z", + "servers": { + "github": { + "description": "GitHub API integration", + "tier": "Official", + "status": "Active", + "transport": "stdio", + "tools": ["create_issue", "search_repositories"], + "image": "ghcr.io/github/github-mcp-server:latest", + "tags": ["github", "api", "production"] + } + } + } +--- +apiVersion: toolhive.stacklok.dev/v1alpha1 +kind: MCPRegistry +metadata: + name: my-registry + namespace: toolhive-system +spec: + displayName: "My MCP Registry" + source: + type: configmap + configmap: + name: my-registry-data + key: registry.json +``` + +Apply with: +```bash +kubectl apply -f my-registry.yaml +``` + +## Sync Operations + +### Automatic Sync + +Configure automatic synchronization with interval-based policies: + +```yaml +spec: + syncPolicy: + interval: "1h" # Sync every hour +``` + +Supported intervals: +- `30s`, `5m`, `1h`, `24h` +- Any valid Go duration format + +### Manual Sync + +Trigger manual sync using annotations: + +```bash +kubectl annotate mcpregistry my-registry toolhive.stacklok.dev/manual-sync="$(date +%s)" +``` + +Or in YAML: +```yaml +metadata: + annotations: + toolhive.stacklok.dev/manual-sync: "1704110400" +``` + +### Sync Status + +Check sync status: +```bash +kubectl get mcpregistry my-registry -o jsonpath='{.status.syncStatus}' +``` + +Status phases: +- `Idle`: No sync needed +- `Syncing`: Sync in progress +- `Complete`: Sync completed successfully +- `Failed`: Sync failed (check `.status.syncStatus.message`) + +## Data Sources + +### ConfigMap Source + +Store registry data in Kubernetes ConfigMaps: + +```yaml +spec: + source: + type: configmap + format: toolhive # or "upstream" + configmap: + name: registry-data + key: registry.json # optional, defaults to "registry.json" +``` + +### Git Source + +Synchronize from Git repositories: + +```yaml +spec: + source: + type: git + format: toolhive + git: + repository: "https://github.com/org/mcp-registry" + branch: "main" + path: "registry.json" # optional, defaults to "registry.json" +``` + +Supported repository URL formats: +- `https://github.com/org/repo` - HTTPS (recommended) +- `git@github.com:org/repo.git` - SSH +- `ssh://git@example.com/repo.git` - SSH with explicit protocol +- `git://example.com/repo.git` - Git protocol +- `file:///path/to/local/repo` - Local filesystem (for testing) + +### Registry Formats + +**ToolHive Format** (default): +- Native ToolHive registry schema +- Supports all ToolHive features +- See [registry schema](../../pkg/registry/data/schema.json) + +**Upstream Format**: +- Standard MCP registry format +- Compatible with community registries +- Automatically converted to ToolHive format +- **Note**: Not supported until the upstream schema is more stable + +## Content Filtering + +### Tag-Based Filtering + +Filter servers by tags: + +```yaml +spec: + filter: + tags: + include: + - "production" + - "database" + exclude: + - "experimental" + - "deprecated" +``` + +### Name-Based Filtering + +Filter servers by name patterns: + +```yaml +spec: + filter: + names: + include: + - "github*" # Include github-* servers + - "*-prod" # Include *-prod servers + exclude: + - "*-beta" # Exclude beta servers + - "test-*" # Exclude test servers +``` + +### Filter Precedence + +1. **Include filters** are applied first (if specified) +2. **Exclude filters** are applied second +3. Empty include list means "include all" +4. Exclusions always take precedence over inclusions + +Example behavior: +```yaml +filter: + tags: + include: ["database", "production"] + exclude: ["experimental"] +# Result: Include database AND production servers, but exclude any experimental ones +``` + +### Automatic Filter Change Detection + +The operator automatically detects when filters are modified and triggers a resync: +- Filter changes are detected using SHA256 hash comparison +- No manual intervention required when updating filter configuration +- Changes are tracked in the `status.lastAppliedFilterHash` field + +## Image Validation + +### Registry-Based Enforcement + +Enforce that MCPServer images must be present in at least one registry: + +```yaml +spec: + enforceServers: true +``` + +When enabled: +- MCPServers in the namespace are validated against registry content +- Only images present in any registry with `enforceServers: true` are allowed +- MCPServers are matched to registry entries by the `server-registry-name` label +- Invalid images cause MCPServer creation to fail + +### MCPServer Matching + +MCPServers are matched to registry entries using the `server-registry-name` label: + +```yaml +apiVersion: toolhive.stacklok.dev/v1alpha1 +kind: MCPServer +metadata: + name: github-server + labels: + server-registry-name: "github" # Must match registry entry name +spec: + image: ghcr.io/github/github-mcp-server:latest +``` + +### Validation Workflow + +1. MCPServer is created/updated in namespace +2. Operator checks if any registry in namespace has `enforceServers: true` +3. If yes, validates that the MCPServer's image matches a registry entry +4. Registry matching is done by `server-registry-name` label +5. Allows or rejects based on validation result + +### Error Handling + +**Note**: Current implementation does not emit Kubernetes events for validation failures. Error details are available in operator logs: + +```bash +# Check operator logs for validation errors +kubectl logs -n toolhive-system deployment/toolhive-operator | grep validation +``` + +## Registry API Service + +Each MCPRegistry automatically deploys an API service for registry access: + +### API Endpoints + +**Registry Data APIs:** +- `GET /api/v1/registry/servers` - List all servers from registry +- `GET /api/v1/registry/servers/{name}` - Get specific server from registry +- `GET /api/v1/registry/info` - Get registry metadata + +**Deployed Server APIs** (ToolHive proprietary): +- `GET /api/v1/registry/servers/deployed` - List all deployed MCPServer instances +- `GET /api/v1/registry/servers/deployed/{name}` - Get deployed servers matching registry name + +**System APIs:** +- `GET /health` - Health check +- `GET /readiness` - Readiness check +- `GET /version` - Version information +- `GET /api/v1/registry/openapi.yaml` - OpenAPI specification + +**Note**: For compatibility with upstream MCP registry APIs, see [MCP Registry Protocol](https://modelcontextprotocol.io/registry) specification. + +### Service Access + +Internal cluster access: +``` +http://{registry-name}-api.{namespace}.svc.cluster.local:8080 +``` + +Port forward for external access: +```bash +kubectl port-forward svc/my-registry-api 8080:8080 +curl http://localhost:8080/servers +``` + +### API Status + +Check API deployment status: +```bash +kubectl get mcpregistry my-registry -o jsonpath='{.status.apiStatus}' +``` + +API phases: +- `Deploying`: API deployment in progress +- `Ready`: API service is available +- `Error`: API deployment failed + +## Status Management + +### Overall Status + +MCPRegistry phase indicates overall state: + +```bash +kubectl get mcpregistry +NAME PHASE MESSAGE +my-registry Ready Registry is ready and API is serving requests +``` + +Phases: +- `Pending`: Initialization in progress +- `Syncing`: Data synchronization active +- `Ready`: Fully operational +- `Failed`: Operation failed +- `Terminating`: Being deleted + +### Detailed Status + +```yaml +status: + phase: Ready + message: "Registry is ready and API is serving requests" + syncStatus: + phase: Complete + message: "Registry data synchronized successfully" + serverCount: 5 + lastSyncTime: "2025-01-14T10:30:00Z" + lastSyncHash: "abc123" + apiStatus: + phase: Ready + endpoint: "http://my-registry-api.toolhive-system.svc.cluster.local:8080" + readySince: "2025-01-14T10:25:00Z" + lastAppliedFilterHash: "def456" + storageRef: + type: configmap + configMapRef: + name: "my-registry-registry-storage" + lastManualSyncTrigger: "1704110400" + conditions: + - type: SyncSuccessful + status: "True" + reason: SyncComplete + - type: APIReady + status: "True" + reason: DeploymentReady +``` + +## Security Best Practices + +### Access Control + +1. **Namespace Isolation**: Deploy registries in dedicated namespaces +2. **RBAC**: Limit registry modification permissions +3. **Service Accounts**: Use dedicated service accounts for registry operations + +### Secret Management + +**Note**: Secret management for Git authentication is planned but not yet implemented. Currently, only public repositories are supported for Git sources. + +### Image Security + +1. **Enable enforcement**: Use `enforceServers: true` to validate images +2. **Registry trust**: Only include trusted registries +3. **Regular updates**: Keep registry data current with security patches + +## Troubleshooting + +### Common Issues + +**Sync Failures**: +```bash +# Check sync status +kubectl get mcpregistry my-registry -o jsonpath='{.status.syncStatus.message}' + +# Common causes: +# - Invalid ConfigMap/Git source +# - Network connectivity issues +# - Malformed registry data +``` + +**API Not Ready**: +```bash +# Check API status +kubectl get mcpregistry my-registry -o jsonpath='{.status.apiStatus}' + +# Check deployment +kubectl get deployment my-registry-api + +# Common causes: +# - Resource constraints +# - Image pull failures +# - Configuration errors +``` + +**Image Validation Errors**: +```bash +# Check MCPServer events +kubectl describe mcpserver problematic-server + +# Common causes: +# - Image not in registry +# - Registry not synced +# - Typo in image name +``` + +### Debug Commands + +```bash +# View registry events +kubectl get events --field-selector involvedObject.kind=MCPRegistry + +# Check operator logs +kubectl logs -n toolhive-system deployment/toolhive-operator + +# Describe registry for detailed status +kubectl describe mcpregistry my-registry + +# Manual sync trigger +kubectl annotate mcpregistry my-registry toolhive.stacklok.dev/manual-sync="$(date +%s)" +``` + +### Log Analysis + +Operator logs show: +- Sync operations and results +- API deployment status +- Image validation attempts +- Error details with context + +Filter for specific registry: +```bash +kubectl logs -n toolhive-system deployment/toolhive-operator | grep "my-registry" +``` + +## Examples + +### Production Registry with Filtering +```yaml +apiVersion: toolhive.stacklok.dev/v1alpha1 +kind: MCPRegistry +metadata: + name: production-registry +spec: + displayName: "Production MCP Servers" + source: + type: configmap + configmap: + name: prod-registry-data + syncPolicy: + interval: "1h" + filter: + tags: + include: ["production"] + exclude: ["experimental", "deprecated"] + enforceServers: true +``` + +### Development Registry +```yaml +apiVersion: toolhive.stacklok.dev/v1alpha1 +kind: MCPRegistry +metadata: + name: dev-registry +spec: + displayName: "Development MCP Servers" + source: + type: git + git: + repository: "https://github.com/org/dev-mcp-registry" + branch: "develop" + # No sync policy = manual sync only + filter: + names: + include: ["dev-*", "*-test"] +``` + +## See Also + +- [MCPServer Documentation](README.md#usage) +- [Operator Installation](../../docs/kind/deploying-toolhive-operator.md) +- [Registry Examples](../../examples/operator/mcp-registries/) +- [Registry Schema](../../pkg/registry/data/schema.json) \ No newline at end of file diff --git a/go.mod b/go.mod index dfc94ad31..14c8805b9 100644 --- a/go.mod +++ b/go.mod @@ -44,15 +44,15 @@ require ( go.uber.org/zap v1.27.0 golang.ngrok.com/ngrok/v2 v2.1.0 golang.org/x/exp/jsonrpc2 v0.0.0-20251002181428-27f1f14c8bb9 - golang.org/x/mod v0.28.0 - golang.org/x/oauth2 v0.31.0 + golang.org/x/mod v0.29.0 + golang.org/x/oauth2 v0.32.0 golang.org/x/sync v0.17.0 - golang.org/x/term v0.35.0 + golang.org/x/term v0.36.0 gopkg.in/yaml.v3 v3.0.1 k8s.io/api v0.34.1 k8s.io/apimachinery v0.34.1 k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 - sigs.k8s.io/controller-runtime v0.22.1 + sigs.k8s.io/controller-runtime v0.22.2 sigs.k8s.io/yaml v1.6.0 ) @@ -263,10 +263,10 @@ require ( go.yaml.in/yaml/v3 v3.0.4 // indirect golang.ngrok.com/muxado/v2 v2.0.1 // indirect golang.org/x/exp/event v0.0.0-20250819193227-8b4c13bb791b // indirect - golang.org/x/net v0.43.0 // indirect + golang.org/x/net v0.44.0 // indirect golang.org/x/text v0.29.0 // indirect golang.org/x/time v0.12.0 // indirect - golang.org/x/tools v0.36.0 // indirect + golang.org/x/tools v0.37.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect google.golang.org/api v0.248.0 // indirect @@ -280,7 +280,7 @@ require ( gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gotest.tools/v3 v3.5.2 // indirect - k8s.io/apiextensions-apiserver v0.34.0 // indirect + k8s.io/apiextensions-apiserver v0.34.1 // indirect k8s.io/klog/v2 v2.130.1 // indirect k8s.io/kube-openapi v0.0.0-20250710124328-f3f2b991d03b // indirect sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8 // indirect @@ -320,6 +320,6 @@ require ( go.opentelemetry.io/otel/trace v1.38.0 golang.org/x/crypto v0.42.0 // indirect golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect - golang.org/x/sys v0.36.0 + golang.org/x/sys v0.37.0 k8s.io/client-go v0.34.1 ) diff --git a/go.sum b/go.sum index e0cb2d1c4..35cc42a60 100644 --- a/go.sum +++ b/go.sum @@ -1840,8 +1840,8 @@ golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.28.0 h1:gQBtGhjxykdjY9YhZpSlZIsbnaE2+PgjfLWUQTnoZ1U= -golang.org/x/mod v0.28.0/go.mod h1:yfB/L0NOf/kmEbXjzCPOx1iK1fRutOydrCMsqRhEBxI= +golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= +golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1904,8 +1904,8 @@ golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= -golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I= +golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1935,8 +1935,8 @@ golang.org/x/oauth2 v0.4.0/go.mod h1:RznEsdpjGAINPTOF0UH/t+xJ75L18YO3Ho6Pyn+uRec golang.org/x/oauth2 v0.5.0/go.mod h1:9/XBHVqLaWO3/BRHs5jbpYCnOZVjj5V0ndyaAM7KB4I= golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw= golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4= -golang.org/x/oauth2 v0.31.0 h1:8Fq0yVZLh4j4YA47vHKFTa9Ew5XIrCP8LC6UeNZnLxo= -golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= +golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -2044,8 +2044,10 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= -golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= +golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/telemetry v0.0.0-20250908211612-aef8a434d053 h1:dHQOQddU4YHS5gY33/6klKjq7Gp3WwMyOXGNp5nzRj8= +golang.org/x/telemetry v0.0.0-20250908211612-aef8a434d053/go.mod h1:+nZKN+XVh4LCiA9DV3ywrzN4gumyCnKjau3NGb9SGoE= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -2058,8 +2060,8 @@ golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= -golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ= -golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA= +golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q= +golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -2157,8 +2159,8 @@ golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s= golang.org/x/tools v0.8.0/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4= -golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= -golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= +golang.org/x/tools v0.37.0 h1:DVSRzp7FwePZW356yEAChSdNcQo6Nsp+fex1SUW09lE= +golang.org/x/tools v0.37.0/go.mod h1:MBN5QPQtLMHVdvsbtarmTNukZDdgwdwlO5qGacAzF0w= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -2483,8 +2485,8 @@ honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9 honnef.co/go/tools v0.1.3/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las= k8s.io/api v0.34.1 h1:jC+153630BMdlFukegoEL8E/yT7aLyQkIVuwhmwDgJM= k8s.io/api v0.34.1/go.mod h1:SB80FxFtXn5/gwzCoN6QCtPD7Vbu5w2n1S0J5gFfTYk= -k8s.io/apiextensions-apiserver v0.34.0 h1:B3hiB32jV7BcyKcMU5fDaDxk882YrJ1KU+ZSkA9Qxoc= -k8s.io/apiextensions-apiserver v0.34.0/go.mod h1:hLI4GxE1BDBy9adJKxUxCEHBGZtGfIg98Q+JmTD7+g0= +k8s.io/apiextensions-apiserver v0.34.1 h1:NNPBva8FNAPt1iSVwIE0FsdrVriRXMsaWFMqJbII2CI= +k8s.io/apiextensions-apiserver v0.34.1/go.mod h1:hP9Rld3zF5Ay2Of3BeEpLAToP+l4s5UlxiHfqRaRcMc= k8s.io/apimachinery v0.34.1 h1:dTlxFls/eikpJxmAC7MVE8oOeP1zryV7iRyIjB0gky4= k8s.io/apimachinery v0.34.1/go.mod h1:/GwIlEcWuTX9zKIg2mbw0LRFIsXwrfoVxn+ef0X13lw= k8s.io/client-go v0.34.1 h1:ZUPJKgXsnKwVwmKKdPfw4tB58+7/Ik3CrjOEhsiZ7mY= @@ -2533,8 +2535,8 @@ rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8 rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= -sigs.k8s.io/controller-runtime v0.22.1 h1:Ah1T7I+0A7ize291nJZdS1CabF/lB4E++WizgV24Eqg= -sigs.k8s.io/controller-runtime v0.22.1/go.mod h1:FwiwRjkRPbiN+zp2QRp7wlTCzbUXxZ/D4OzuQUDwBHY= +sigs.k8s.io/controller-runtime v0.22.2 h1:cK2l8BGWsSWkXz09tcS4rJh95iOLney5eawcK5A33r4= +sigs.k8s.io/controller-runtime v0.22.2/go.mod h1:+QX1XUpTXN4mLoblf4tqr5CQcyHPAki2HLXqQMY6vh8= sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8 h1:gBQPwqORJ8d8/YNZWEjoZs7npUVDpVXUUOFfW6CgAqE= sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg= sigs.k8s.io/randfill v1.0.0 h1:JfjMILfT8A6RbawdsK2JXGBR5AQVfd+9TbzrlneTyrU= diff --git a/pkg/auth/tokenexchange/exchange.go b/pkg/auth/tokenexchange/exchange.go index 9eafbe3b5..4786a6db8 100644 --- a/pkg/auth/tokenexchange/exchange.go +++ b/pkg/auth/tokenexchange/exchange.go @@ -158,8 +158,8 @@ func (c clientAuthentication) String() string { c.ClientID, clientSecret) } -// Config holds the configuration for token exchange. -type Config struct { +// ExchangeConfig holds the configuration for token exchange. +type ExchangeConfig struct { // TokenURL is the OAuth 2.0 token endpoint URL TokenURL string @@ -185,8 +185,8 @@ type Config struct { HTTPClient *http.Client } -// Validate checks if the Config contains all required fields. -func (c *Config) Validate() error { +// Validate checks if the ExchangeConfig contains all required fields. +func (c *ExchangeConfig) Validate() error { if c.TokenURL == "" { return fmt.Errorf("TokenURL is required") } @@ -211,7 +211,7 @@ func (c *Config) Validate() error { // tokenSource implements oauth2.TokenSource for token exchange. type tokenSource struct { ctx context.Context - conf *Config + conf *ExchangeConfig } // Token implements oauth2.TokenSource interface. @@ -281,7 +281,7 @@ func (ts *tokenSource) Token() (*oauth2.Token, error) { } // TokenSource returns an oauth2.TokenSource that performs token exchange. -func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource { +func (c *ExchangeConfig) TokenSource(ctx context.Context) oauth2.TokenSource { return &tokenSource{ ctx: ctx, conf: c, diff --git a/pkg/auth/tokenexchange/exchange_test.go b/pkg/auth/tokenexchange/exchange_test.go index 54bf3ac49..655bd0ae5 100644 --- a/pkg/auth/tokenexchange/exchange_test.go +++ b/pkg/auth/tokenexchange/exchange_test.go @@ -125,7 +125,7 @@ func TestTokenSource_Token_Success(t *testing.T) { defer server.Close() // Create config with test server - config := &Config{ + config := &ExchangeConfig{ TokenURL: server.URL, ClientID: "test-client-id", ClientSecret: "test-client-secret", @@ -166,7 +166,7 @@ func TestTokenSource_Token_WithRefreshToken(t *testing.T) { })) defer server.Close() - config := &Config{ + config := &ExchangeConfig{ TokenURL: server.URL, ClientID: "test-client-id", ClientSecret: "test-client-secret", @@ -198,7 +198,7 @@ func TestTokenSource_Token_NoExpiry(t *testing.T) { })) defer server.Close() - config := &Config{ + config := &ExchangeConfig{ TokenURL: server.URL, ClientID: "test-client-id", ClientSecret: "test-client-secret", @@ -221,7 +221,7 @@ func TestTokenSource_Token_SubjectTokenProviderError(t *testing.T) { t.Parallel() providerErr := errors.New("failed to get token from provider") - config := &Config{ + config := &ExchangeConfig{ TokenURL: "https://example.com/token", ClientID: "test-client-id", ClientSecret: "test-client-secret", @@ -251,7 +251,7 @@ func TestTokenSource_Token_ContextCancellation(t *testing.T) { })) defer server.Close() - config := &Config{ + config := &ExchangeConfig{ TokenURL: server.URL, ClientID: "test-client-id", ClientSecret: "test-client-secret", @@ -800,7 +800,7 @@ func TestSubjectTokenProvider_Variants(t *testing.T) { })) defer server.Close() - config := &Config{ + config := &ExchangeConfig{ TokenURL: server.URL, ClientID: "test-client-id", ClientSecret: "test-client-secret", @@ -1036,10 +1036,10 @@ func TestExchangeToken_ScopeArray(t *testing.T) { } // TestConfig_TokenSource tests that TokenSource creates a valid tokenSource. -func TestConfig_TokenSource(t *testing.T) { +func TestExchangeConfig_TokenSource(t *testing.T) { t.Parallel() - config := &Config{ + config := &ExchangeConfig{ TokenURL: "https://example.com/token", ClientID: "test-client-id", ClientSecret: "test-client-secret", @@ -1175,14 +1175,14 @@ func TestClientAuthentication_Fields(t *testing.T) { } // TestConfig_Fields tests Config struct fields. -func TestConfig_Fields(t *testing.T) { +func TestExchangeConfig_Fields(t *testing.T) { t.Parallel() provider := func() (string, error) { return "token", nil } - config := &Config{ + config := &ExchangeConfig{ TokenURL: "https://example.com/token", ClientID: "test-client-id", ClientSecret: "test-client-secret", diff --git a/pkg/auth/tokenexchange/middleware.go b/pkg/auth/tokenexchange/middleware.go new file mode 100644 index 000000000..1dc245da4 --- /dev/null +++ b/pkg/auth/tokenexchange/middleware.go @@ -0,0 +1,239 @@ +package tokenexchange + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/golang-jwt/jwt/v5" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/transport/types" +) + +// Middleware type constant +const ( + MiddlewareType = "tokenexchange" +) + +// Header injection strategy constants +const ( + // HeaderStrategyReplace replaces the Authorization header with the exchanged token + HeaderStrategyReplace = "replace" + // HeaderStrategyCustom adds the exchanged token to a custom header + HeaderStrategyCustom = "custom" +) + +var errUnknownStrategy = errors.New("unknown token injection strategy") + +// MiddlewareParams represents the parameters for token exchange middleware +type MiddlewareParams struct { + TokenExchangeConfig *Config `json:"token_exchange_config,omitempty"` +} + +// Config holds configuration for token exchange middleware +type Config struct { + // TokenURL is the OAuth 2.0 token endpoint URL + TokenURL string `json:"token_url"` + + // ClientID is the OAuth 2.0 client identifier + ClientID string `json:"client_id"` + + // ClientSecret is the OAuth 2.0 client secret + ClientSecret string `json:"client_secret"` + + // Audience is the target audience for the exchanged token + Audience string `json:"audience"` + + // Scopes is the list of scopes to request for the exchanged token + Scopes []string `json:"scopes,omitempty"` + + // HeaderStrategy determines how to inject the token + // Valid values: HeaderStrategyReplace (default), HeaderStrategyCustom + HeaderStrategy string `json:"header_strategy,omitempty"` + + // ExternalTokenHeaderName is the name of the custom header to use when HeaderStrategy is "custom" + ExternalTokenHeaderName string `json:"external_token_header_name,omitempty"` +} + +// Middleware wraps token exchange middleware functionality +type Middleware struct { + middleware types.MiddlewareFunction +} + +// Handler returns the middleware function used by the proxy. +func (m *Middleware) Handler() types.MiddlewareFunction { + return m.middleware +} + +// Close cleans up any resources used by the middleware. +func (*Middleware) Close() error { + // Token exchange middleware doesn't need cleanup + return nil +} + +// CreateMiddleware factory function for token exchange middleware +func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRunner) error { + var params MiddlewareParams + if err := json.Unmarshal(config.Parameters, ¶ms); err != nil { + return fmt.Errorf("failed to unmarshal token exchange middleware parameters: %w", err) + } + + // Token exchange config is required when this middleware type is specified + if params.TokenExchangeConfig == nil { + return fmt.Errorf("token exchange configuration is required but not provided") + } + + // Validate configuration + if err := validateTokenExchangeConfig(params.TokenExchangeConfig); err != nil { + return fmt.Errorf("invalid token exchange configuration: %w", err) + } + + middleware, err := CreateTokenExchangeMiddlewareFromClaims(*params.TokenExchangeConfig) + if err != nil { + return fmt.Errorf("invalid token exchange middleware config: %w", err) + } + + tokenExchangeMw := &Middleware{ + middleware: middleware, + } + + // Add middleware to runner + runner.AddMiddleware(tokenExchangeMw) + + return nil +} + +// validateTokenExchangeConfig validates the token exchange configuration +func validateTokenExchangeConfig(config *Config) error { + if config.HeaderStrategy == HeaderStrategyCustom && config.ExternalTokenHeaderName == "" { + return fmt.Errorf("external_token_header_name must be specified when header_strategy is '%s'", HeaderStrategyCustom) + } + + if config.HeaderStrategy != "" && + config.HeaderStrategy != HeaderStrategyReplace && + config.HeaderStrategy != HeaderStrategyCustom { + return fmt.Errorf("invalid header_strategy: %s (valid values: '%s', '%s')", + config.HeaderStrategy, HeaderStrategyReplace, HeaderStrategyCustom) + } + + return nil +} + +// injectionFunc is a function that injects a token into an HTTP request +type injectionFunc func(*http.Request, string) error + +// createReplaceInjector creates an injection function that replaces the Authorization header +func createReplaceInjector() injectionFunc { + return func(r *http.Request, token string) error { + logger.Debugf("Token exchange successful, replacing Authorization header") + r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + return nil + } +} + +// createCustomInjector creates an injection function that adds the token to a custom header +func createCustomInjector(headerName string) injectionFunc { + // Validate header name at creation time + if headerName == "" { + return func(_ *http.Request, _ string) error { + return fmt.Errorf("external_token_header_name must be specified when header_strategy is '%s'", HeaderStrategyCustom) + } + } + + return func(r *http.Request, token string) error { + logger.Debugf("Token exchange successful, adding token to custom header: %s", headerName) + r.Header.Set(headerName, fmt.Sprintf("Bearer %s", token)) + return nil + } +} + +// CreateTokenExchangeMiddlewareFromClaims creates a middleware that uses token claims +// from the auth middleware to perform token exchange. +// This is a public function for direct usage in proxy commands. +func CreateTokenExchangeMiddlewareFromClaims(config Config) (types.MiddlewareFunction, error) { + // Determine injection strategy at startup time + strategy := config.HeaderStrategy + if strategy == "" { + strategy = HeaderStrategyReplace // Default to replace for backwards compatibility + } + + var injectToken injectionFunc + switch strategy { + case HeaderStrategyReplace: + injectToken = createReplaceInjector() + case HeaderStrategyCustom: + injectToken = createCustomInjector(config.ExternalTokenHeaderName) + default: + return nil, fmt.Errorf("%w: invalid header injection strategy %s", errUnknownStrategy, strategy) + } + + // Create base exchange config at startup time with all static fields + baseExchangeConfig := ExchangeConfig{ + TokenURL: config.TokenURL, + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + Audience: config.Audience, + Scopes: config.Scopes, + // SubjectTokenProvider will be set per request + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get claims from the auth middleware + claims, ok := r.Context().Value(auth.ClaimsContextKey{}).(jwt.MapClaims) + if !ok { + logger.Debug("No claims found in context, proceeding without token exchange") + next.ServeHTTP(w, r) + return + } + + // Extract the original token from the Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { + logger.Debug("No valid Bearer token found, proceeding without token exchange") + next.ServeHTTP(w, r) + return + } + + subjectToken := strings.TrimPrefix(authHeader, "Bearer ") + if subjectToken == "" { + logger.Debug("Empty Bearer token, proceeding without token exchange") + next.ServeHTTP(w, r) + return + } + + // Log some claim information for debugging + if sub, exists := claims["sub"]; exists { + logger.Debugf("Performing token exchange for subject: %v", sub) + } + + // Create a copy of the base config with the request-specific subject token + exchangeConfig := baseExchangeConfig + exchangeConfig.SubjectTokenProvider = func() (string, error) { + return subjectToken, nil + } + + // Get token from token source + tokenSource := exchangeConfig.TokenSource(r.Context()) + exchangedToken, err := tokenSource.Token() + if err != nil { + logger.Warnf("Token exchange failed: %v", err) + http.Error(w, "Token exchange failed", http.StatusUnauthorized) + return + } + + // Inject the exchanged token into the request using the pre-selected strategy + if err := injectToken(r, exchangedToken.AccessToken); err != nil { + logger.Warnf("Failed to inject token: %v", err) + http.Error(w, "Token injection failed", http.StatusInternalServerError) + return + } + + next.ServeHTTP(w, r) + }) + }, nil +} diff --git a/pkg/auth/tokenexchange/middleware_test.go b/pkg/auth/tokenexchange/middleware_test.go new file mode 100644 index 000000000..f4fd3439c --- /dev/null +++ b/pkg/auth/tokenexchange/middleware_test.go @@ -0,0 +1,621 @@ +package tokenexchange + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/transport/types" + "github.com/stacklok/toolhive/pkg/transport/types/mocks" +) + +// TestValidateTokenExchangeConfig tests configuration validation. +func TestValidateTokenExchangeConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config *Config + expectError bool + errorMsg string + }{ + { + name: "valid replace strategy explicit", + config: &Config{ + HeaderStrategy: HeaderStrategyReplace, + }, + expectError: false, + }, + { + name: "valid custom strategy with header name", + config: &Config{ + HeaderStrategy: HeaderStrategyCustom, + ExternalTokenHeaderName: "X-Upstream-Token", + }, + expectError: false, + }, + { + name: "valid empty strategy defaults to replace", + config: &Config{ + HeaderStrategy: "", + }, + expectError: false, + }, + { + name: "invalid custom strategy missing header name", + config: &Config{ + HeaderStrategy: HeaderStrategyCustom, + }, + expectError: true, + errorMsg: "external_token_header_name must be specified", + }, + { + name: "invalid strategy name", + config: &Config{ + HeaderStrategy: "invalid-strategy", + }, + expectError: true, + errorMsg: "invalid header_strategy", + }, + { + name: "unknown strategy", + config: &Config{ + HeaderStrategy: "query-param", + }, + expectError: true, + errorMsg: "invalid header_strategy", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := validateTokenExchangeConfig(tt.config) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestInjectToken tests the token injection strategies. +func TestInjectToken(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config Config + originalAuthHeader string + newToken string + expectError bool + errorMsg string + expectedAuthHeader string + expectedCustomHeader string + customHeaderName string + }{ + { + name: "replace strategy replaces Authorization header", + config: Config{ + HeaderStrategy: HeaderStrategyReplace, + }, + originalAuthHeader: "Bearer original-token", + newToken: "new-token", + expectError: false, + expectedAuthHeader: "Bearer new-token", + }, + { + name: "empty strategy defaults to replace", + config: Config{ + HeaderStrategy: "", + }, + originalAuthHeader: "Bearer original-token", + newToken: "new-token", + expectError: false, + expectedAuthHeader: "Bearer new-token", + }, + { + name: "custom strategy preserves original and adds custom header", + config: Config{ + HeaderStrategy: HeaderStrategyCustom, + ExternalTokenHeaderName: "X-Upstream-Token", + }, + originalAuthHeader: "Bearer original-token", + newToken: "new-token", + expectError: false, + expectedAuthHeader: "Bearer original-token", + expectedCustomHeader: "Bearer new-token", + customHeaderName: "X-Upstream-Token", + }, + { + name: "custom strategy with different header name", + config: Config{ + HeaderStrategy: HeaderStrategyCustom, + ExternalTokenHeaderName: "X-External-Auth", + }, + originalAuthHeader: "Bearer original-token", + newToken: "exchanged-token", + expectError: false, + expectedAuthHeader: "Bearer original-token", + expectedCustomHeader: "Bearer exchanged-token", + customHeaderName: "X-External-Auth", + }, + { + name: "custom strategy missing header name fails", + config: Config{ + HeaderStrategy: HeaderStrategyCustom, + }, + newToken: "new-token", + expectError: true, + errorMsg: "external_token_header_name must be specified", + }, + { + name: "unsupported strategy fails", + config: Config{ + HeaderStrategy: "unsupported-strategy", + }, + newToken: "new-token", + expectError: true, + errorMsg: "unsupported header_strategy", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + if tt.originalAuthHeader != "" { + req.Header.Set("Authorization", tt.originalAuthHeader) + } + + // Create the injector function based on the strategy (mimics CreateTokenExchangeMiddlewareFromClaims) + strategy := tt.config.HeaderStrategy + if strategy == "" { + strategy = HeaderStrategyReplace + } + + var injectToken injectionFunc + switch strategy { + case HeaderStrategyReplace: + injectToken = createReplaceInjector() + case HeaderStrategyCustom: + injectToken = createCustomInjector(tt.config.ExternalTokenHeaderName) + default: + injectToken = func(_ *http.Request, _ string) error { + return fmt.Errorf("unsupported header_strategy: %s (valid values: '%s', '%s')", + strategy, HeaderStrategyReplace, HeaderStrategyCustom) + } + } + + err := injectToken(req, tt.newToken) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedAuthHeader, req.Header.Get("Authorization")) + if tt.customHeaderName != "" { + assert.Equal(t, tt.expectedCustomHeader, req.Header.Get(tt.customHeaderName)) + } + } + }) + } +} + +// TestCreateTokenExchangeMiddlewareFromClaims_Success tests successful token exchange flow. +func TestCreateTokenExchangeMiddlewareFromClaims_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + headerStrategy string + customHeaderName string + scopes []string + expectedAuthHeader string + expectedCustomHeader string + expectedScopesReceived string + }{ + { + name: "replace strategy", + headerStrategy: HeaderStrategyReplace, + scopes: nil, + expectedAuthHeader: "Bearer exchanged-token", + expectedScopesReceived: "", + }, + { + name: "custom strategy", + headerStrategy: HeaderStrategyCustom, + customHeaderName: "X-Upstream-Token", + scopes: nil, + expectedAuthHeader: "Bearer original-token", + expectedCustomHeader: "Bearer exchanged-token", + expectedScopesReceived: "", + }, + { + name: "with scopes", + headerStrategy: HeaderStrategyReplace, + scopes: []string{"read", "write", "admin"}, + expectedAuthHeader: "Bearer exchanged-token", + expectedScopesReceived: "read write admin", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var receivedScopes string + + // Create mock OAuth server + exchangeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if tt.expectedScopesReceived != "" { + _ = r.ParseForm() + receivedScopes = r.Form.Get("scope") + } + + resp := response{ + AccessToken: "exchanged-token", + TokenType: "Bearer", + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + ExpiresIn: 3600, + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) + })) + defer exchangeServer.Close() + + config := Config{ + TokenURL: exchangeServer.URL, + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + Audience: "https://api.example.com", + Scopes: tt.scopes, + HeaderStrategy: tt.headerStrategy, + ExternalTokenHeaderName: tt.customHeaderName, + } + + middleware, err := CreateTokenExchangeMiddlewareFromClaims(config) + require.NoError(t, err) + + // Test handler verifies token injection + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, tt.expectedAuthHeader, r.Header.Get("Authorization")) + if tt.customHeaderName != "" { + assert.Equal(t, tt.expectedCustomHeader, r.Header.Get(tt.customHeaderName)) + } + w.WriteHeader(http.StatusOK) + }) + + // Create request with claims and token + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Authorization", "Bearer original-token") + claims := jwt.MapClaims{ + "sub": "user123", + "aud": "test-audience", + } + ctx := context.WithValue(req.Context(), auth.ClaimsContextKey{}, claims) + req = req.WithContext(ctx) + + // Execute middleware + rec := httptest.NewRecorder() + handler := middleware(testHandler) + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + if tt.expectedScopesReceived != "" { + assert.Equal(t, tt.expectedScopesReceived, receivedScopes) + } + }) + } +} + +// TestCreateTokenExchangeMiddlewareFromClaims_PassThrough tests cases where middleware passes through. +func TestCreateTokenExchangeMiddlewareFromClaims_PassThrough(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupReq func(*http.Request) *http.Request + description string + }{ + { + name: "no claims in context", + setupReq: func(req *http.Request) *http.Request { + req.Header.Set("Authorization", "Bearer original-token") + return req + }, + description: "should pass through without token exchange", + }, + { + name: "no Authorization header", + setupReq: func(req *http.Request) *http.Request { + claims := jwt.MapClaims{"sub": "user123"} + ctx := context.WithValue(req.Context(), auth.ClaimsContextKey{}, claims) + return req.WithContext(ctx) + }, + description: "should pass through without token exchange", + }, + { + name: "non-Bearer token", + setupReq: func(req *http.Request) *http.Request { + req.Header.Set("Authorization", "Basic dXNlcjpwYXNz") + claims := jwt.MapClaims{"sub": "user123"} + ctx := context.WithValue(req.Context(), auth.ClaimsContextKey{}, claims) + return req.WithContext(ctx) + }, + description: "should pass through with non-Bearer auth", + }, + { + name: "empty Bearer token", + setupReq: func(req *http.Request) *http.Request { + req.Header.Set("Authorization", "Bearer ") + claims := jwt.MapClaims{"sub": "user123"} + ctx := context.WithValue(req.Context(), auth.ClaimsContextKey{}, claims) + return req.WithContext(ctx) + }, + description: "should pass through with empty Bearer token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + config := Config{ + TokenURL: "https://example.com/token", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + } + + middleware, err := CreateTokenExchangeMiddlewareFromClaims(config) + require.NoError(t, err) + + handlerCalled := false + testHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + handlerCalled = true + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req = tt.setupReq(req) + + rec := httptest.NewRecorder() + handler := middleware(testHandler) + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code, tt.description) + assert.True(t, handlerCalled, "handler should be called") + }) + } +} + +// TestCreateTokenExchangeMiddlewareFromClaims_Failures tests error scenarios. +func TestCreateTokenExchangeMiddlewareFromClaims_Failures(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + serverResponse func(w http.ResponseWriter, r *http.Request) + headerStrategy string + customHeaderName string + expectedStatusCode int + expectedBodyMsg string + }{ + { + name: "token exchange returns 401", + serverResponse: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_client"}`)) + }, + headerStrategy: HeaderStrategyReplace, + expectedStatusCode: http.StatusUnauthorized, + expectedBodyMsg: "Token exchange failed", + }, + { + name: "token exchange returns 500", + serverResponse: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"server_error"}`)) + }, + headerStrategy: HeaderStrategyReplace, + expectedStatusCode: http.StatusUnauthorized, + expectedBodyMsg: "Token exchange failed", + }, + { + name: "invalid injection config", + serverResponse: func(w http.ResponseWriter, _ *http.Request) { + resp := response{ + AccessToken: "exchanged-token", + TokenType: "Bearer", + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) + }, + headerStrategy: HeaderStrategyCustom, + customHeaderName: "", // Missing header name causes injection failure + expectedStatusCode: http.StatusInternalServerError, + expectedBodyMsg: "Token injection failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + exchangeServer := httptest.NewServer(http.HandlerFunc(tt.serverResponse)) + defer exchangeServer.Close() + + config := Config{ + TokenURL: exchangeServer.URL, + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + HeaderStrategy: tt.headerStrategy, + ExternalTokenHeaderName: tt.customHeaderName, + } + + middleware, err := CreateTokenExchangeMiddlewareFromClaims(config) + require.NoError(t, err) + + testHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + t.Fatal("handler should not be called on failure") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("Authorization", "Bearer original-token") + claims := jwt.MapClaims{"sub": "user123"} + ctx := context.WithValue(req.Context(), auth.ClaimsContextKey{}, claims) + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + handler := middleware(testHandler) + handler.ServeHTTP(rec, req) + + assert.Equal(t, tt.expectedStatusCode, rec.Code) + assert.Contains(t, rec.Body.String(), tt.expectedBodyMsg) + }) + } +} + +// TestCreateMiddleware tests the factory function. +func TestCreateMiddleware(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + params MiddlewareParams + expectError bool + errorMsg string + expectAddMiddleware bool + }{ + { + name: "valid config creates middleware", + params: MiddlewareParams{ + TokenExchangeConfig: &Config{ + TokenURL: "https://example.com/token", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + HeaderStrategy: HeaderStrategyReplace, + }, + }, + expectError: false, + expectAddMiddleware: true, + }, + { + name: "nil config returns error", + params: MiddlewareParams{ + TokenExchangeConfig: nil, + }, + expectError: true, + errorMsg: "token exchange configuration is required", + expectAddMiddleware: false, + }, + { + name: "invalid config fails validation", + params: MiddlewareParams{ + TokenExchangeConfig: &Config{ + HeaderStrategy: HeaderStrategyCustom, + // Missing ExternalTokenHeaderName + }, + }, + expectError: true, + errorMsg: "invalid token exchange configuration", + expectAddMiddleware: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRunner := mocks.NewMockMiddlewareRunner(ctrl) + + if tt.expectAddMiddleware { + mockRunner.EXPECT().AddMiddleware(gomock.Any()).Do(func(mw types.Middleware) { + _, ok := mw.(*Middleware) + assert.True(t, ok, "Expected middleware to be of type *tokenexchange.Middleware") + }) + } + + paramsJSON, err := json.Marshal(tt.params) + require.NoError(t, err) + + config := &types.MiddlewareConfig{ + Type: MiddlewareType, + Parameters: paramsJSON, + } + + err = CreateMiddleware(config, mockRunner) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +// TestCreateMiddleware_InvalidJSON tests error handling for malformed parameters. +func TestCreateMiddleware_InvalidJSON(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRunner := mocks.NewMockMiddlewareRunner(ctrl) + + config := &types.MiddlewareConfig{ + Type: MiddlewareType, + Parameters: []byte(`{invalid json}`), + } + + err := CreateMiddleware(config, mockRunner) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to unmarshal token exchange middleware parameters") +} + +// TestMiddleware_Methods tests the Middleware struct methods. +func TestMiddleware_Methods(t *testing.T) { + t.Parallel() + + middlewareFunc := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + }) + } + + mw := &Middleware{ + middleware: middlewareFunc, + } + + // Test Handler returns the function + handler := mw.Handler() + assert.NotNil(t, handler) + + // Test Close returns no error + err := mw.Close() + assert.NoError(t, err) +} diff --git a/pkg/authz/middleware_test.go b/pkg/authz/middleware_test.go index f13ede350..fba625a91 100644 --- a/pkg/authz/middleware_test.go +++ b/pkg/authz/middleware_test.go @@ -19,9 +19,9 @@ import ( "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/logger" mcpparser "github.com/stacklok/toolhive/pkg/mcp" - "github.com/stacklok/toolhive/pkg/testkit" "github.com/stacklok/toolhive/pkg/transport/types" "github.com/stacklok/toolhive/pkg/transport/types/mocks" + "github.com/stacklok/toolhive/test/testkit" ) func TestMiddleware(t *testing.T) { diff --git a/pkg/mcp/tool_filter.go b/pkg/mcp/tool_filter.go index 6761d8d01..a5bad1b83 100644 --- a/pkg/mcp/tool_filter.go +++ b/pkg/mcp/tool_filter.go @@ -245,7 +245,10 @@ func NewToolCallMappingMiddleware(opts ...ToolMiddlewareOption) (types.Middlewar next.ServeHTTP(w, r) return } + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + // TODO: find a reasonable way to test this + r.ContentLength = int64(len(bodyBytes)) // According to the current version of the MCP spec at // https://modelcontextprotocol.io/specification/2025-06-18/schema#calltoolrequest diff --git a/pkg/mcp/tool_middleware_test.go b/pkg/mcp/tool_middleware_test.go index f019239f7..3de2ec3fc 100644 --- a/pkg/mcp/tool_middleware_test.go +++ b/pkg/mcp/tool_middleware_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/stacklok/toolhive/pkg/testkit" + "github.com/stacklok/toolhive/test/testkit" ) func TestNewListToolsMappingMiddleware_Scenarios(t *testing.T) { diff --git a/pkg/registry/data/registry.json b/pkg/registry/data/registry.json index 5657151a4..933da3b40 100644 --- a/pkg/registry/data/registry.json +++ b/pkg/registry/data/registry.json @@ -1,7 +1,7 @@ { "$schema": "https://raw.githubusercontent.com/stacklok/toolhive/main/pkg/registry/data/schema.json", "version": "1.0.0", - "last_updated": "2025-10-08T00:17:09Z", + "last_updated": "2025-10-09T00:16:54Z", "servers": { "adb-mysql-mcp-server": { "description": "Official MCP server for AnalyticDB for MySQL of Alibaba Cloud", @@ -711,7 +711,7 @@ "database", "storage" ], - "image": "mcr.microsoft.com/azure-sdk/azure-mcp:0.8.4", + "image": "mcr.microsoft.com/azure-sdk/azure-mcp:0.8.5", "permissions": { "network": { "outbound": { @@ -1189,7 +1189,7 @@ }, "context7": { "description": "Context7 MCP pulls version-specific docs and code examples directly into your prompt", - "tier": "Community", + "tier": "Official", "status": "Active", "transport": "stdio", "tools": [ @@ -1197,26 +1197,36 @@ "get-library-docs" ], "metadata": { - "stars": 31260, + "stars": 32836, "pulls": 313, - "last_updated": "2025-09-24T02:28:46Z" + "last_updated": "2025-10-07T20:17:57Z" }, "repository_url": "https://github.com/upstash/context7", "tags": [ "documentation", - "modelcontextprotocol" + "code-examples" ], - "image": "ghcr.io/stacklok/dockyard/npx/context7:1.0.20", + "image": "ghcr.io/stacklok/dockyard/npx/context7:1.0.21", "permissions": { "network": { "outbound": { - "insecure_allow_all": true, + "allow_host": [ + "context7.com" + ], "allow_port": [ 443 ] } } }, + "env_vars": [ + { + "name": "CONTEXT7_API_KEY", + "description": "API key for higher rate limits", + "required": false, + "secret": true + } + ], "provenance": { "sigstore_url": "tuf-repo-cdn.sigstore.dev", "repository_uri": "https://github.com/stacklok/dockyard", @@ -5065,7 +5075,7 @@ "metadata": { "stars": 0, "pulls": 0, - "last_updated": "2025-09-10T02:27:19Z" + "last_updated": "2025-10-08T02:28:06Z" }, "tags": [ "remote", @@ -5113,7 +5123,7 @@ "metadata": { "stars": 0, "pulls": 0, - "last_updated": "2025-09-10T02:27:19Z" + "last_updated": "2025-10-08T02:28:06Z" }, "tags": [ "remote", @@ -5133,6 +5143,31 @@ }, "url": "https://mcp.canva.com/mcp" }, + "context7-remote": { + "description": "Remote Context7 MCP server pulls version-specific docs and code examples directly into your prompt", + "tier": "Official", + "status": "Active", + "transport": "streamable-http", + "tools": [ + "resolve-library-id", + "get-library-docs" + ], + "metadata": { + "stars": 31260, + "pulls": 313, + "last_updated": "2025-09-24T02:28:46Z" + }, + "repository_url": "https://github.com/upstash/context7", + "tags": [ + "documentation", + "code-examples" + ], + "custom_metadata": { + "author": "Upstash", + "homepage": "https://context7.com/" + }, + "url": "https://mcp.context7.com/mcp" + }, "github-remote": { "description": "GitHub's official MCP server for repositories, issues, PRs, actions, and security with OAuth", "tier": "Official", @@ -5232,9 +5267,9 @@ "update_pull_request_branch" ], "metadata": { - "stars": 22439, + "stars": 23322, "pulls": 0, - "last_updated": "2025-09-10T02:27:19Z" + "last_updated": "2025-10-08T02:28:06Z" }, "repository_url": "https://github.com/github/github-mcp-server", "tags": [ @@ -5284,7 +5319,7 @@ "metadata": { "stars": 0, "pulls": 0, - "last_updated": "2025-09-10T02:27:19Z" + "last_updated": "2025-10-08T02:28:06Z" }, "tags": [ "remote", @@ -5319,7 +5354,7 @@ "metadata": { "stars": 0, "pulls": 0, - "last_updated": "2025-09-10T02:27:19Z" + "last_updated": "2025-10-08T02:28:06Z" }, "tags": [ "remote", diff --git a/pkg/telemetry/middleware.go b/pkg/telemetry/middleware.go index 19e943bc3..42dce7235 100644 --- a/pkg/telemetry/middleware.go +++ b/pkg/telemetry/middleware.go @@ -1,9 +1,11 @@ package telemetry import ( + "bufio" "context" "encoding/json" "fmt" + "net" "net/http" "os" "strconv" @@ -130,6 +132,7 @@ func (m *HTTPMiddleware) Handler(next http.Handler) http.Handler { ResponseWriter: w, statusCode: http.StatusOK, bytesWritten: 0, + wroteHeader: false, } // Add HTTP attributes @@ -403,21 +406,45 @@ type responseWriter struct { http.ResponseWriter statusCode int bytesWritten int64 + wroteHeader bool } // WriteHeader captures the status code. func (rw *responseWriter) WriteHeader(statusCode int) { + if rw.wroteHeader { + logger.Infof("WriteHeader called multiple times: attempted status %d, already wrote status %d", statusCode, rw.statusCode) + return // Prevent multiple WriteHeader calls + } rw.statusCode = statusCode + rw.wroteHeader = true rw.ResponseWriter.WriteHeader(statusCode) } // Write captures the number of bytes written. func (rw *responseWriter) Write(data []byte) (int, error) { + if !rw.wroteHeader { + rw.WriteHeader(http.StatusOK) + } n, err := rw.ResponseWriter.Write(data) rw.bytesWritten += int64(n) return n, err } +// Flush implements http.Flusher interface +func (rw *responseWriter) Flush() { + if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} + +// Hijack implements http.Hijacker interface +func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacker, ok := rw.ResponseWriter.(http.Hijacker); ok { + return hijacker.Hijack() + } + return nil, nil, fmt.Errorf("underlying http.ResponseWriter does not implement http.Hijacker") +} + // recordMetrics records request metrics. func (m *HTTPMiddleware) recordMetrics(ctx context.Context, r *http.Request, rw *responseWriter, duration time.Duration) { // Get MCP method from context if available diff --git a/pkg/telemetry/middleware_test.go b/pkg/telemetry/middleware_test.go index 5a7a36822..f9aa81898 100644 --- a/pkg/telemetry/middleware_test.go +++ b/pkg/telemetry/middleware_test.go @@ -1491,3 +1491,320 @@ func TestFactoryMiddleware_Integration(t *testing.T) { assert.NoError(t, err) }) } + +// TestResponseWriter_WriteHeader tests the WriteHeader method +func TestResponseWriter_WriteHeader(t *testing.T) { + t.Parallel() + + t.Run("sets status code correctly", func(t *testing.T) { + t.Parallel() + rec := httptest.NewRecorder() + rw := &responseWriter{ + ResponseWriter: rec, + statusCode: http.StatusOK, + bytesWritten: 0, + wroteHeader: false, + } + + rw.WriteHeader(http.StatusCreated) + + assert.Equal(t, http.StatusCreated, rw.statusCode) + assert.True(t, rw.wroteHeader) + assert.Equal(t, http.StatusCreated, rec.Code) + }) + + t.Run("prevents duplicate WriteHeader calls", func(t *testing.T) { + t.Parallel() + rec := httptest.NewRecorder() + rw := &responseWriter{ + ResponseWriter: rec, + statusCode: http.StatusOK, + bytesWritten: 0, + wroteHeader: false, + } + + rw.WriteHeader(http.StatusCreated) + rw.WriteHeader(http.StatusBadRequest) // Should be ignored + + assert.Equal(t, http.StatusCreated, rw.statusCode, "Status code should not change after first write") + assert.True(t, rw.wroteHeader) + assert.Equal(t, http.StatusCreated, rec.Code) + }) +} + +// TestResponseWriter_Write tests the Write method +func TestResponseWriter_Write(t *testing.T) { + t.Parallel() + + t.Run("writes data and tracks bytes", func(t *testing.T) { + t.Parallel() + rec := httptest.NewRecorder() + rw := &responseWriter{ + ResponseWriter: rec, + statusCode: http.StatusOK, + bytesWritten: 0, + wroteHeader: false, + } + + data := []byte("Hello, World!") + n, err := rw.Write(data) + + assert.NoError(t, err) + assert.Equal(t, len(data), n) + assert.Equal(t, int64(len(data)), rw.bytesWritten) + assert.Equal(t, "Hello, World!", rec.Body.String()) + }) + + t.Run("automatically writes header on first Write", func(t *testing.T) { + t.Parallel() + rec := httptest.NewRecorder() + rw := &responseWriter{ + ResponseWriter: rec, + statusCode: http.StatusOK, + bytesWritten: 0, + wroteHeader: false, + } + + _, err := rw.Write([]byte("test")) + + assert.NoError(t, err) + assert.True(t, rw.wroteHeader) + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("accumulates bytes written", func(t *testing.T) { + t.Parallel() + rec := httptest.NewRecorder() + rw := &responseWriter{ + ResponseWriter: rec, + statusCode: http.StatusOK, + bytesWritten: 0, + wroteHeader: false, + } + + rw.Write([]byte("Hello")) + rw.Write([]byte(", ")) + rw.Write([]byte("World!")) + + assert.Equal(t, int64(13), rw.bytesWritten) + assert.Equal(t, "Hello, World!", rec.Body.String()) + }) +} + +// TestResponseWriter_Flush tests the Flush method +func TestResponseWriter_Flush(t *testing.T) { + t.Parallel() + + t.Run("calls Flush on underlying Flusher", func(t *testing.T) { + t.Parallel() + rec := httptest.NewRecorder() + rw := &responseWriter{ + ResponseWriter: rec, + statusCode: http.StatusOK, + bytesWritten: 0, + wroteHeader: false, + } + + // Write some data + rw.Write([]byte("test data")) + + // Flush should not panic even though httptest.ResponseRecorder implements Flusher + assert.NotPanics(t, func() { + rw.Flush() + }) + }) + + t.Run("handles non-Flusher ResponseWriter gracefully", func(t *testing.T) { + t.Parallel() + // Create a minimal ResponseWriter that doesn't implement Flusher + minimalWriter := &minimalResponseWriter{ + header: make(http.Header), + body: []byte{}, + } + + rw := &responseWriter{ + ResponseWriter: minimalWriter, + statusCode: http.StatusOK, + bytesWritten: 0, + wroteHeader: false, + } + + // Flush should not panic when underlying writer doesn't support it + assert.NotPanics(t, func() { + rw.Flush() + }) + }) +} + +// TestResponseWriter_Hijack tests the Hijack method +func TestResponseWriter_Hijack(t *testing.T) { + t.Parallel() + + t.Run("returns error when Hijacker not supported", func(t *testing.T) { + t.Parallel() + rec := httptest.NewRecorder() + rw := &responseWriter{ + ResponseWriter: rec, + statusCode: http.StatusOK, + bytesWritten: 0, + wroteHeader: false, + } + + conn, buf, err := rw.Hijack() + + assert.Error(t, err) + assert.Nil(t, conn) + assert.Nil(t, buf) + assert.Contains(t, err.Error(), "http.Hijacker") + }) +} + +// TestResponseWriter_HeadersIntegration tests that headers work correctly +func TestResponseWriter_HeadersIntegration(t *testing.T) { + t.Parallel() + + t.Run("headers are set before WriteHeader", func(t *testing.T) { + t.Parallel() + rec := httptest.NewRecorder() + rw := &responseWriter{ + ResponseWriter: rec, + statusCode: http.StatusOK, + bytesWritten: 0, + wroteHeader: false, + } + + // Set headers before writing + rw.Header().Set("Content-Type", "application/json") + rw.Header().Set("X-Custom-Header", "test-value") + rw.WriteHeader(http.StatusCreated) + + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + assert.Equal(t, "test-value", rec.Header().Get("X-Custom-Header")) + assert.Equal(t, http.StatusCreated, rec.Code) + }) + + t.Run("headers are preserved with Write", func(t *testing.T) { + t.Parallel() + rec := httptest.NewRecorder() + rw := &responseWriter{ + ResponseWriter: rec, + statusCode: http.StatusOK, + bytesWritten: 0, + wroteHeader: false, + } + + // Set headers + rw.Header().Set("X-Session-Id", "test-session-123") + rw.Header().Set("Content-Type", "application/json") + + // Write data (should auto-call WriteHeader) + rw.Write([]byte(`{"status":"ok"}`)) + + assert.Equal(t, "test-session-123", rec.Header().Get("X-Session-Id")) + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + assert.Equal(t, `{"status":"ok"}`, rec.Body.String()) + }) +} + +// TestResponseWriter_WithMiddlewareChain tests responseWriter in a middleware chain +func TestResponseWriter_WithMiddlewareChain(t *testing.T) { + t.Parallel() + + config := Config{ + ServiceName: "test-service", + ServiceVersion: "1.0.0", + } + tracerProvider := tracenoop.NewTracerProvider() + meterProvider := noop.NewMeterProvider() + + middleware := NewHTTPMiddleware(config, tracerProvider, meterProvider, "test-server", "stdio") + + t.Run("headers set by handler are preserved", func(t *testing.T) { + t.Parallel() + + // Create a handler that sets a custom header + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("X-Test-Header", "middleware-test") + w.Header().Set("Mcp-Session-Id", "session-12345") + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + }) + + // Wrap with telemetry middleware + wrappedHandler := middleware(handler) + + // Create test request + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + + // Execute request + wrappedHandler.ServeHTTP(rec, req) + + // Verify headers are preserved + assert.Equal(t, "middleware-test", rec.Header().Get("X-Test-Header")) + assert.Equal(t, "session-12345", rec.Header().Get("Mcp-Session-Id")) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "success", rec.Body.String()) + }) + + t.Run("multiple writes work correctly", func(t *testing.T) { + t.Parallel() + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.Write([]byte("Part 1, ")) + w.Write([]byte("Part 2, ")) + w.Write([]byte("Part 3")) + }) + + wrappedHandler := middleware(handler) + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(rec, req) + + assert.Equal(t, "Part 1, Part 2, Part 3", rec.Body.String()) + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("error status codes are captured", func(t *testing.T) { + t.Parallel() + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("not found")) + }) + + wrappedHandler := middleware(handler) + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + assert.Equal(t, "not found", rec.Body.String()) + }) +} + +// minimalResponseWriter is a minimal implementation of http.ResponseWriter +// that doesn't implement Flusher or Hijacker interfaces +type minimalResponseWriter struct { + header http.Header + body []byte + statusCode int +} + +func (m *minimalResponseWriter) Header() http.Header { + return m.header +} + +func (m *minimalResponseWriter) Write(data []byte) (int, error) { + m.body = append(m.body, data...) + return len(data), nil +} + +func (m *minimalResponseWriter) WriteHeader(statusCode int) { + m.statusCode = statusCode +} diff --git a/pkg/testkit/sse_server.go b/test/testkit/sse_server.go similarity index 100% rename from pkg/testkit/sse_server.go rename to test/testkit/sse_server.go diff --git a/pkg/testkit/streamable_server.go b/test/testkit/streamable_server.go similarity index 100% rename from pkg/testkit/streamable_server.go rename to test/testkit/streamable_server.go diff --git a/pkg/testkit/testkit.go b/test/testkit/testkit.go similarity index 100% rename from pkg/testkit/testkit.go rename to test/testkit/testkit.go diff --git a/pkg/testkit/testkit_test.go b/test/testkit/testkit_test.go similarity index 100% rename from pkg/testkit/testkit_test.go rename to test/testkit/testkit_test.go