diff --git a/README.md b/README.md index 8707e6c..60f2e63 100644 --- a/README.md +++ b/README.md @@ -1073,6 +1073,74 @@ $server = Server::make() ->build(); ``` +### Middleware Support + +Both `HttpServerTransport` and `StreamableHttpServerTransport` support PSR-7 compatible middleware for intercepting and modifying HTTP requests and responses. Middleware allows you to extract common functionality like authentication, logging, CORS handling, and request validation into reusable components. + +Middleware must be a valid PHP callable that accepts a PSR-7 `ServerRequestInterface` as the first argument and a `callable` as the second argument. + +```php +use Psr\Http\Message\ServerRequestInterface; +use Psr\Http\Message\ResponseInterface; +use React\Promise\PromiseInterface; + +class AuthMiddleware +{ + public function __invoke(ServerRequestInterface $request, callable $next) + { + $apiKey = $request->getHeaderLine('Authorization'); + if (empty($apiKey)) { + return new Response(401, [], 'Authorization required'); + } + + $request = $request->withAttribute('user_id', $this->validateApiKey($apiKey)); + $result = $next($request); + + return match (true) { + $result instanceof PromiseInterface => $result->then(fn($response) => $this->handle($response)), + $result instanceof ResponseInterface => $this->handle($result), + default => $result + }; + } + + private function handle($response) + { + return $response instanceof ResponseInterface + ? $response->withHeader('X-Auth-Provider', 'mcp-server') + : $response; + } +} + +$middlewares = [ + new AuthMiddleware(), + new LoggingMiddleware(), + function(ServerRequestInterface $request, callable $next) { + $result = $next($request); + return match (true) { + $result instanceof PromiseInterface => $result->then(function($response) { + return $response instanceof ResponseInterface + ? $response->withHeader('Access-Control-Allow-Origin', '*') + : $response; + }), + $result instanceof ResponseInterface => $result->withHeader('Access-Control-Allow-Origin', '*'), + default => $result + }; + } +]; + +$transport = new StreamableHttpServerTransport( + host: '127.0.0.1', + port: 8080, + middlewares: $middlewares +); +``` + +**Important Considerations:** + +- **Response Handling**: Middleware must handle both synchronous `ResponseInterface` and asynchronous `PromiseInterface` returns from `$next($request)`, since ReactPHP operates asynchronously +- **Invokable Pattern**: The recommended pattern is to use invokable classes with a separate `handle()` method to process responses, making the async logic reusable +- **Execution Order**: Middleware executes in the order provided, with the last middleware being closest to your MCP handlers + ### SSL Context Configuration For HTTPS deployments of `StreamableHttpServerTransport`, configure SSL context options: diff --git a/src/Transports/HttpServerTransport.php b/src/Transports/HttpServerTransport.php index 1965704..c609dab 100644 --- a/src/Transports/HttpServerTransport.php +++ b/src/Transports/HttpServerTransport.php @@ -62,17 +62,25 @@ class HttpServerTransport implements ServerTransportInterface, LoggerAwareInterf * @param int $port Port to listen on (e.g., 8080). * @param string $mcpPathPrefix URL prefix for MCP endpoints (e.g., 'mcp'). * @param array|null $sslContext Optional SSL context options for React SocketServer (for HTTPS). + * @param array $middlewares Middlewares to be applied to the HTTP server. */ public function __construct( private readonly string $host = '127.0.0.1', private readonly int $port = 8080, private readonly string $mcpPathPrefix = 'mcp', private readonly ?array $sslContext = null, + private array $middlewares = [] ) { $this->logger = new NullLogger(); $this->loop = Loop::get(); $this->ssePath = '/' . trim($mcpPathPrefix, '/') . '/sse'; $this->messagePath = '/' . trim($mcpPathPrefix, '/') . '/message'; + + foreach ($this->middlewares as $mw) { + if (!is_callable($mw)) { + throw new \InvalidArgumentException('All provided middlewares must be callable.'); + } + } } public function setLogger(LoggerInterface $logger): void @@ -114,7 +122,8 @@ public function listen(): void $this->loop ); - $this->http = new HttpServer($this->loop, $this->createRequestHandler()); + $handlers = array_merge($this->middlewares, [$this->createRequestHandler()]); + $this->http = new HttpServer($this->loop, ...$handlers); $this->http->listen($this->socket); $this->socket->on('error', function (Throwable $error) { @@ -261,7 +270,10 @@ protected function handleMessagePostRequest(ServerRequestInterface $request): Re return new Response(400, ['Content-Type' => 'application/json'], json_encode($error, $jsonEncodeFlags)); } - $this->emit('message', [$message, $sessionId]); + $context = [ + 'request' => $request, + ]; + $this->emit('message', [$message, $sessionId, $context]); return new Response(202, ['Content-Type' => 'text/plain'], 'Accepted'); } diff --git a/src/Transports/StreamableHttpServerTransport.php b/src/Transports/StreamableHttpServerTransport.php index a836e71..9d9a349 100644 --- a/src/Transports/StreamableHttpServerTransport.php +++ b/src/Transports/StreamableHttpServerTransport.php @@ -67,6 +67,9 @@ class StreamableHttpServerTransport implements ServerTransportInterface, LoggerA /** * @param bool $enableJsonResponse If true, the server will return JSON responses instead of starting an SSE stream. + * @param bool $stateless If true, the server will not emit client_connected events. + * @param EventStoreInterface $eventStore If provided, the server will replay events to the client. + * @param array $middlewares Middlewares to be applied to the HTTP server. * This can be useful for simple request/response scenarios without streaming. */ public function __construct( @@ -76,12 +79,19 @@ public function __construct( private ?array $sslContext = null, private readonly bool $enableJsonResponse = true, private readonly bool $stateless = false, - ?EventStoreInterface $eventStore = null + ?EventStoreInterface $eventStore = null, + private array $middlewares = [] ) { $this->logger = new NullLogger(); $this->loop = Loop::get(); $this->mcpPath = '/' . trim($mcpPath, '/'); $this->eventStore = $eventStore; + + foreach ($this->middlewares as $mw) { + if (!is_callable($mw)) { + throw new \InvalidArgumentException('All provided middlewares must be callable.'); + } + } } protected function generateId(): string @@ -119,7 +129,8 @@ public function listen(): void $this->loop ); - $this->http = new HttpServer($this->loop, $this->createRequestHandler()); + $handlers = array_merge($this->middlewares, [$this->createRequestHandler()]); + $this->http = new HttpServer($this->loop, ...$handlers); $this->http->listen($this->socket); $this->socket->on('error', function (Throwable $error) { diff --git a/tests/Fixtures/General/RequestAttributeChecker.php b/tests/Fixtures/General/RequestAttributeChecker.php new file mode 100644 index 0000000..a79861a --- /dev/null +++ b/tests/Fixtures/General/RequestAttributeChecker.php @@ -0,0 +1,21 @@ +request->getAttribute('middleware-attr'); + if ($attribute === 'middleware-value') { + return TextContent::make('middleware-value-found: ' . $attribute); + } + + return TextContent::make('middleware-value-not-found: ' . $attribute); + } +} diff --git a/tests/Fixtures/Middlewares/ErrorMiddleware.php b/tests/Fixtures/Middlewares/ErrorMiddleware.php new file mode 100644 index 0000000..f2be79e --- /dev/null +++ b/tests/Fixtures/Middlewares/ErrorMiddleware.php @@ -0,0 +1,18 @@ +getUri()->getPath(), '/error-middleware')) { + throw new \Exception('Middleware error'); + } + return $next($request); + } +} diff --git a/tests/Fixtures/Middlewares/FirstMiddleware.php b/tests/Fixtures/Middlewares/FirstMiddleware.php new file mode 100644 index 0000000..31fb153 --- /dev/null +++ b/tests/Fixtures/Middlewares/FirstMiddleware.php @@ -0,0 +1,33 @@ + $result->then(fn($response) => $this->handle($response)), + $result instanceof ResponseInterface => $this->handle($result), + default => $result + }; + } + + private function handle($response) + { + if ($response instanceof ResponseInterface) { + $existing = $response->getHeaderLine('X-Middleware-Order'); + $new = $existing ? $existing . ',first' : 'first'; + return $response->withHeader('X-Middleware-Order', $new); + } + return $response; + } +} diff --git a/tests/Fixtures/Middlewares/HeaderMiddleware.php b/tests/Fixtures/Middlewares/HeaderMiddleware.php new file mode 100644 index 0000000..bc8a456 --- /dev/null +++ b/tests/Fixtures/Middlewares/HeaderMiddleware.php @@ -0,0 +1,30 @@ + $result->then(fn($response) => $this->handle($response)), + $result instanceof ResponseInterface => $this->handle($result), + default => $result + }; + } + + private function handle($response) + { + return $response instanceof ResponseInterface + ? $response->withHeader('X-Test-Middleware', 'header-added') + : $response; + } +} diff --git a/tests/Fixtures/Middlewares/RequestAttributeMiddleware.php b/tests/Fixtures/Middlewares/RequestAttributeMiddleware.php new file mode 100644 index 0000000..2b07f7a --- /dev/null +++ b/tests/Fixtures/Middlewares/RequestAttributeMiddleware.php @@ -0,0 +1,16 @@ +withAttribute('middleware-attr', 'middleware-value'); + return $next($request); + } +} diff --git a/tests/Fixtures/Middlewares/SecondMiddleware.php b/tests/Fixtures/Middlewares/SecondMiddleware.php new file mode 100644 index 0000000..275746e --- /dev/null +++ b/tests/Fixtures/Middlewares/SecondMiddleware.php @@ -0,0 +1,33 @@ + $result->then(fn($response) => $this->handle($response)), + $result instanceof ResponseInterface => $this->handle($result), + default => $result + }; + } + + private function handle($response) + { + if ($response instanceof ResponseInterface) { + $existing = $response->getHeaderLine('X-Middleware-Order'); + $new = $existing ? $existing . ',second' : 'second'; + return $response->withHeader('X-Middleware-Order', $new); + } + return $response; + } +} diff --git a/tests/Fixtures/Middlewares/ShortCircuitMiddleware.php b/tests/Fixtures/Middlewares/ShortCircuitMiddleware.php new file mode 100644 index 0000000..04ef518 --- /dev/null +++ b/tests/Fixtures/Middlewares/ShortCircuitMiddleware.php @@ -0,0 +1,19 @@ +getUri()->getPath(), '/short-circuit')) { + return new Response(418, [], 'Short-circuited by middleware'); + } + return $next($request); + } +} diff --git a/tests/Fixtures/Middlewares/ThirdMiddleware.php b/tests/Fixtures/Middlewares/ThirdMiddleware.php new file mode 100644 index 0000000..fe647c3 --- /dev/null +++ b/tests/Fixtures/Middlewares/ThirdMiddleware.php @@ -0,0 +1,33 @@ + $result->then(fn($response) => $this->handle($response)), + $result instanceof ResponseInterface => $this->handle($result), + default => $result + }; + } + + private function handle($response) + { + if ($response instanceof ResponseInterface) { + $existing = $response->getHeaderLine('X-Middleware-Order'); + $new = $existing ? $existing . ',third' : 'third'; + return $response->withHeader('X-Middleware-Order', $new); + } + return $response; + } +} diff --git a/tests/Fixtures/ServerScripts/HttpTestServer.php b/tests/Fixtures/ServerScripts/HttpTestServer.php index 07d66ae..9bd6b86 100755 --- a/tests/Fixtures/ServerScripts/HttpTestServer.php +++ b/tests/Fixtures/ServerScripts/HttpTestServer.php @@ -10,6 +10,14 @@ use PhpMcp\Server\Tests\Fixtures\General\ToolHandlerFixture; use PhpMcp\Server\Tests\Fixtures\General\ResourceHandlerFixture; use PhpMcp\Server\Tests\Fixtures\General\PromptHandlerFixture; +use PhpMcp\Server\Tests\Fixtures\General\RequestAttributeChecker; +use PhpMcp\Server\Tests\Fixtures\Middlewares\HeaderMiddleware; +use PhpMcp\Server\Tests\Fixtures\Middlewares\RequestAttributeMiddleware; +use PhpMcp\Server\Tests\Fixtures\Middlewares\ShortCircuitMiddleware; +use PhpMcp\Server\Tests\Fixtures\Middlewares\FirstMiddleware; +use PhpMcp\Server\Tests\Fixtures\Middlewares\SecondMiddleware; +use PhpMcp\Server\Tests\Fixtures\Middlewares\ThirdMiddleware; +use PhpMcp\Server\Tests\Fixtures\Middlewares\ErrorMiddleware; use Psr\Log\AbstractLogger; use Psr\Log\NullLogger; @@ -32,11 +40,22 @@ public function log($level, \Stringable|string $message, array $context = []): v ->withServerInfo('HttpIntegrationTestServer', '0.1.0') ->withLogger($logger) ->withTool([ToolHandlerFixture::class, 'greet'], 'greet_http_tool') + ->withTool([RequestAttributeChecker::class, 'checkAttribute'], 'check_request_attribute_tool') ->withResource([ResourceHandlerFixture::class, 'getStaticText'], "test://http/static", 'static_http_resource') ->withPrompt([PromptHandlerFixture::class, 'generateSimpleGreeting'], 'simple_http_prompt') ->build(); - $transport = new HttpServerTransport($host, $port, $mcpPathPrefix); + $middlewares = [ + new HeaderMiddleware(), + new RequestAttributeMiddleware(), + new ShortCircuitMiddleware(), + new FirstMiddleware(), + new SecondMiddleware(), + new ThirdMiddleware(), + new ErrorMiddleware() + ]; + + $transport = new HttpServerTransport($host, $port, $mcpPathPrefix, null, $middlewares); $server->listen($transport); exit(0); diff --git a/tests/Fixtures/ServerScripts/StreamableHttpTestServer.php b/tests/Fixtures/ServerScripts/StreamableHttpTestServer.php index b7cdf6a..2c85436 100755 --- a/tests/Fixtures/ServerScripts/StreamableHttpTestServer.php +++ b/tests/Fixtures/ServerScripts/StreamableHttpTestServer.php @@ -10,6 +10,14 @@ use PhpMcp\Server\Tests\Fixtures\General\ToolHandlerFixture; use PhpMcp\Server\Tests\Fixtures\General\ResourceHandlerFixture; use PhpMcp\Server\Tests\Fixtures\General\PromptHandlerFixture; +use PhpMcp\Server\Tests\Fixtures\General\RequestAttributeChecker; +use PhpMcp\Server\Tests\Fixtures\Middlewares\HeaderMiddleware; +use PhpMcp\Server\Tests\Fixtures\Middlewares\RequestAttributeMiddleware; +use PhpMcp\Server\Tests\Fixtures\Middlewares\ShortCircuitMiddleware; +use PhpMcp\Server\Tests\Fixtures\Middlewares\FirstMiddleware; +use PhpMcp\Server\Tests\Fixtures\Middlewares\SecondMiddleware; +use PhpMcp\Server\Tests\Fixtures\Middlewares\ThirdMiddleware; +use PhpMcp\Server\Tests\Fixtures\Middlewares\ErrorMiddleware; use PhpMcp\Server\Defaults\InMemoryEventStore; use Psr\Log\AbstractLogger; use Psr\Log\NullLogger; @@ -41,17 +49,29 @@ public function log($level, \Stringable|string $message, array $context = []): v ->withTool([ToolHandlerFixture::class, 'greet'], 'greet_streamable_tool') ->withTool([ToolHandlerFixture::class, 'sum'], 'sum_streamable_tool') // For batch testing ->withTool([ToolHandlerFixture::class, 'toolReadsContext'], 'tool_reads_context') // for Context testing + ->withTool([RequestAttributeChecker::class, 'checkAttribute'], 'check_request_attribute_tool') ->withResource([ResourceHandlerFixture::class, 'getStaticText'], "test://streamable/static", 'static_streamable_resource') ->withPrompt([PromptHandlerFixture::class, 'generateSimpleGreeting'], 'simple_streamable_prompt') ->build(); + $middlewares = [ + new HeaderMiddleware(), + new RequestAttributeMiddleware(), + new ShortCircuitMiddleware(), + new FirstMiddleware(), + new SecondMiddleware(), + new ThirdMiddleware(), + new ErrorMiddleware() + ]; + $transport = new StreamableHttpServerTransport( host: $host, port: $port, mcpPath: $mcpPath, enableJsonResponse: $enableJsonResponse, stateless: $stateless, - eventStore: $eventStore + eventStore: $eventStore, + middlewares: $middlewares ); $server->listen($transport); diff --git a/tests/Integration/HttpServerTransportTest.php b/tests/Integration/HttpServerTransportTest.php index dab32bd..44534b4 100644 --- a/tests/Integration/HttpServerTransportTest.php +++ b/tests/Integration/HttpServerTransportTest.php @@ -259,7 +259,7 @@ expect($toolListResponse['id'])->toBe('tool-list-http-1'); expect($toolListResponse)->not->toHaveKey('error'); - expect($toolListResponse['result']['tools'])->toBeArray()->toHaveCount(1); + expect($toolListResponse['result']['tools'])->toBeArray()->toHaveCount(2); expect($toolListResponse['result']['tools'][0]['name'])->toBe('greet_http_tool'); $this->sseClient->close(); @@ -417,3 +417,96 @@ $this->fail("Request to unknown path failed with unexpected error: " . $e->getMessage()); } })->group('integration', 'http_transport'); + +it('executes middleware that adds headers to response', function () { + $this->sseClient = new MockSseClient(); + $sseBaseUrl = "http://" . HTTP_SERVER_HOST . ":" . $this->port . "/" . HTTP_MCP_PATH_PREFIX . "/sse"; + + // 1. Connect + await($this->sseClient->connect($sseBaseUrl)); + await(delay(0.05, $this->loop)); + + // 2. Check that the middleware-added header is present in the response + expect($this->sseClient->lastConnectResponse->getHeaderLine('X-Test-Middleware'))->toBe('header-added'); + + $this->sseClient->close(); +})->group('integration', 'http_transport', 'middleware'); + +it('executes middleware that modifies request attributes', function () { + $this->sseClient = new MockSseClient(); + $sseBaseUrl = "http://" . HTTP_SERVER_HOST . ":" . $this->port . "/" . HTTP_MCP_PATH_PREFIX . "/sse"; + + // 1. Connect + await($this->sseClient->connect($sseBaseUrl)); + await(delay(0.05, $this->loop)); + + // 2. Initialize + await($this->sseClient->sendHttpRequest('init-middleware-attr', 'initialize', [ + 'protocolVersion' => Protocol::LATEST_PROTOCOL_VERSION, + 'clientInfo' => ['name' => 'MiddlewareTestClient'], + 'capabilities' => [] + ])); + await($this->sseClient->getNextMessageResponse('init-middleware-attr')); + await($this->sseClient->sendHttpNotification('notifications/initialized')); + await(delay(0.05, $this->loop)); + + // 3. Call tool that checks for middleware-added attribute + await($this->sseClient->sendHttpRequest('tool-attr-check', 'tools/call', [ + 'name' => 'check_request_attribute_tool', + 'arguments' => [] + ])); + $toolResponse = await($this->sseClient->getNextMessageResponse('tool-attr-check')); + + expect($toolResponse['result']['content'][0]['text'])->toBe('middleware-value-found: middleware-value'); + + $this->sseClient->close(); +})->group('integration', 'http_transport', 'middleware'); + +it('executes middleware that can short-circuit request processing', function () { + $browser = new Browser($this->loop); + $shortCircuitUrl = "http://" . HTTP_SERVER_HOST . ":" . $this->port . "/" . HTTP_MCP_PATH_PREFIX . "/short-circuit"; + + $promise = $browser->get($shortCircuitUrl); + + try { + $response = await(timeout($promise, HTTP_PROCESS_TIMEOUT_SECONDS - 2, $this->loop)); + $this->fail("Expected a 418 status code response, but request succeeded"); + } catch (ResponseException $e) { + expect($e->getResponse()->getStatusCode())->toBe(418); + $body = (string) $e->getResponse()->getBody(); + expect($body)->toBe('Short-circuited by middleware'); + } catch (\Throwable $e) { + $this->fail("Short-circuit middleware test failed: " . $e->getMessage()); + } +})->group('integration', 'http_transport', 'middleware'); + +it('executes multiple middlewares in correct order', function () { + $this->sseClient = new MockSseClient(); + $sseBaseUrl = "http://" . HTTP_SERVER_HOST . ":" . $this->port . "/" . HTTP_MCP_PATH_PREFIX . "/sse"; + + // 1. Connect + await($this->sseClient->connect($sseBaseUrl)); + await(delay(0.05, $this->loop)); + + // 2. Check that headers from multiple middlewares are present in correct order + expect($this->sseClient->lastConnectResponse->getHeaderLine('X-Middleware-Order'))->toBe('third,second,first'); + + $this->sseClient->close(); +})->group('integration', 'http_transport', 'middleware'); + +it('handles middleware that throws exceptions gracefully', function () { + $browser = new Browser($this->loop); + $errorUrl = "http://" . HTTP_SERVER_HOST . ":" . $this->port . "/" . HTTP_MCP_PATH_PREFIX . "/error-middleware"; + + $promise = $browser->get($errorUrl); + + try { + await(timeout($promise, HTTP_PROCESS_TIMEOUT_SECONDS - 2, $this->loop)); + $this->fail("Error middleware should have thrown an exception."); + } catch (ResponseException $e) { + expect($e->getResponse()->getStatusCode())->toBe(500); + $body = (string) $e->getResponse()->getBody(); + // ReactPHP handles exceptions and returns a generic error message + expect($body)->toContain('Internal Server Error'); + } +})->group('integration', 'http_transport', 'middleware'); diff --git a/tests/Integration/StreamableHttpServerTransportTest.php b/tests/Integration/StreamableHttpServerTransportTest.php index f7b7ebb..6d88c8d 100644 --- a/tests/Integration/StreamableHttpServerTransportTest.php +++ b/tests/Integration/StreamableHttpServerTransportTest.php @@ -217,7 +217,7 @@ expect($toolListResult['statusCode'])->toBe(200); expect($toolListResult['body']['id'])->toBe('tool-list-json-1'); expect($toolListResult['body']['result']['tools'])->toBeArray(); - expect(count($toolListResult['body']['result']['tools']))->toBe(3); + expect(count($toolListResult['body']['result']['tools']))->toBe(4); expect($toolListResult['body']['result']['tools'][0]['name'])->toBe('greet_streamable_tool'); expect($toolListResult['body']['result']['tools'][1]['name'])->toBe('sum_streamable_tool'); expect($toolListResult['body']['result']['tools'][2]['name'])->toBe('tool_reads_context'); @@ -460,7 +460,7 @@ expect($toolListResponse['id'])->toBe('tool-list-stream-1'); expect($toolListResponse)->not->toHaveKey('error'); expect($toolListResponse['result']['tools'])->toBeArray(); - expect(count($toolListResponse['result']['tools']))->toBe(3); + expect(count($toolListResponse['result']['tools']))->toBe(4); expect($toolListResponse['result']['tools'][0]['name'])->toBe('greet_streamable_tool'); expect($toolListResponse['result']['tools'][1]['name'])->toBe('sum_streamable_tool'); expect($toolListResponse['result']['tools'][2]['name'])->toBe('tool_reads_context'); @@ -663,7 +663,7 @@ expect($toolListResult['body']['id'])->toBe('tool-list-stateless-1'); expect($toolListResult['body'])->not->toHaveKey('error'); expect($toolListResult['body']['result']['tools'])->toBeArray(); - expect(count($toolListResult['body']['result']['tools']))->toBe(3); + expect(count($toolListResult['body']['result']['tools']))->toBe(4); expect($toolListResult['body']['result']['tools'][0]['name'])->toBe('greet_streamable_tool'); expect($toolListResult['body']['result']['tools'][1]['name'])->toBe('sum_streamable_tool'); expect($toolListResult['body']['result']['tools'][2]['name'])->toBe('tool_reads_context'); @@ -862,3 +862,103 @@ expect($decodedBody['error']['message'])->toContain('Invalid or expired session'); } })->group('integration', 'streamable_http_json'); + +it('executes middleware that adds headers to response', function () { + $this->process = new Process($this->jsonModeCommand, getcwd() ?: null, null, []); + $this->process->start(); + $this->jsonClient = new MockJsonHttpClient(STREAMABLE_HTTP_HOST, $this->port, STREAMABLE_MCP_PATH); + await(delay(0.1)); + + // 1. Send a request and check that middleware-added header is present + $response = await($this->jsonClient->sendRequest('initialize', [ + 'protocolVersion' => Protocol::LATEST_PROTOCOL_VERSION, + 'clientInfo' => ['name' => 'MiddlewareTestClient'], + 'capabilities' => [] + ], 'init-middleware-headers')); + + // Check that the response has the header added by middleware + expect($this->jsonClient->lastResponseHeaders)->toContain('X-Test-Middleware: header-added'); +})->group('integration', 'streamable_http', 'middleware'); + +it('executes middleware that modifies request attributes', function () { + $this->process = new Process($this->jsonModeCommand, getcwd() ?: null, null, []); + $this->process->start(); + $this->jsonClient = new MockJsonHttpClient(STREAMABLE_HTTP_HOST, $this->port, STREAMABLE_MCP_PATH); + await(delay(0.1)); + + // 1. Initialize + await($this->jsonClient->sendRequest('initialize', [ + 'protocolVersion' => Protocol::LATEST_PROTOCOL_VERSION, + 'clientInfo' => ['name' => 'MiddlewareAttrTestClient', 'version' => '1.0'], + 'capabilities' => [] + ], 'init-middleware-attr')); + await($this->jsonClient->sendNotification('notifications/initialized')); + + // 2. Call tool that checks for middleware-added attribute + $toolResponse = await($this->jsonClient->sendRequest('tools/call', [ + 'name' => 'check_request_attribute_tool', + 'arguments' => [] + ], 'tool-attr-check')); + + expect($toolResponse['body']['result']['content'][0]['text'])->toBe('middleware-value-found: middleware-value'); +})->group('integration', 'streamable_http', 'middleware'); + +it('executes middleware that can short-circuit request processing', function () { + $this->process = new Process($this->jsonModeCommand, getcwd() ?: null, null, []); + $this->process->start(); + await(delay(0.1)); + + $browser = new Browser(); + $shortCircuitUrl = "http://" . STREAMABLE_HTTP_HOST . ":" . $this->port . "/" . STREAMABLE_MCP_PATH . "/short-circuit"; + + $promise = $browser->get($shortCircuitUrl); + + try { + $response = await(timeout($promise, STREAMABLE_HTTP_PROCESS_TIMEOUT - 2)); + $this->fail("Expected a 418 status code response, but request succeeded"); + } catch (ResponseException $e) { + expect($e->getResponse()->getStatusCode())->toBe(418); + $body = (string) $e->getResponse()->getBody(); + expect($body)->toBe('Short-circuited by middleware'); + } catch (\Throwable $e) { + $this->fail("Short-circuit middleware test failed: " . $e->getMessage()); + } +})->group('integration', 'streamable_http', 'middleware'); + +it('executes multiple middlewares in correct order', function () { + $this->process = new Process($this->jsonModeCommand, getcwd() ?: null, null, []); + $this->process->start(); + $this->jsonClient = new MockJsonHttpClient(STREAMABLE_HTTP_HOST, $this->port, STREAMABLE_MCP_PATH); + await(delay(0.1)); + + // 1. Send a request and check middleware order + await($this->jsonClient->sendRequest('initialize', [ + 'protocolVersion' => Protocol::LATEST_PROTOCOL_VERSION, + 'clientInfo' => ['name' => 'MiddlewareOrderTestClient'], + 'capabilities' => [] + ], 'init-middleware-order')); + + // Check that headers from multiple middlewares are present in correct order + expect($this->jsonClient->lastResponseHeaders)->toContain('X-Middleware-Order: third,second,first'); +})->group('integration', 'streamable_http', 'middleware'); + +it('handles middleware that throws exceptions gracefully', function () { + $this->process = new Process($this->jsonModeCommand, getcwd() ?: null, null, []); + $this->process->start(); + await(delay(0.1)); + + $browser = new Browser(); + $errorUrl = "http://" . STREAMABLE_HTTP_HOST . ":" . $this->port . "/" . STREAMABLE_MCP_PATH . "/error-middleware"; + + $promise = $browser->get($errorUrl); + + try { + await(timeout($promise, STREAMABLE_HTTP_PROCESS_TIMEOUT - 2)); + $this->fail("Error middleware should have thrown an exception."); + } catch (ResponseException $e) { + expect($e->getResponse()->getStatusCode())->toBe(500); + $body = (string) $e->getResponse()->getBody(); + // ReactPHP handles exceptions and returns a generic error message + expect($body)->toContain('Internal Server Error'); + } +})->group('integration', 'streamable_http', 'middleware'); diff --git a/tests/Mocks/Clients/MockJsonHttpClient.php b/tests/Mocks/Clients/MockJsonHttpClient.php index 364c90b..cae47d7 100644 --- a/tests/Mocks/Clients/MockJsonHttpClient.php +++ b/tests/Mocks/Clients/MockJsonHttpClient.php @@ -13,6 +13,7 @@ class MockJsonHttpClient public Browser $browser; public string $baseUrl; public ?string $sessionId = null; + public array $lastResponseHeaders = []; // Store last response headers for testing public function __construct(string $host, int $port, string $mcpPath, int $timeout = 2) { @@ -37,6 +38,14 @@ public function sendRequest(string $method, array $params = [], ?string $id = nu return $this->browser->post($this->baseUrl, $headers, $body) ->then(function (ResponseInterface $response) use ($method) { + // Store response headers for testing + $this->lastResponseHeaders = []; + foreach ($response->getHeaders() as $name => $values) { + foreach ($values as $value) { + $this->lastResponseHeaders[] = "{$name}: {$value}"; + } + } + $bodyContent = (string) $response->getBody()->getContents(); $statusCode = $response->getStatusCode(); diff --git a/tests/Mocks/Clients/MockSseClient.php b/tests/Mocks/Clients/MockSseClient.php index ad54374..1423a6e 100644 --- a/tests/Mocks/Clients/MockSseClient.php +++ b/tests/Mocks/Clients/MockSseClient.php @@ -22,6 +22,7 @@ class MockSseClient private array $receivedSseEvents = []; // Stores raw SSE events (type, data, id) public ?string $endpointUrl = null; // The /message endpoint URL provided by server public ?string $clientId = null; // The clientId from the /message endpoint URL + public ?ResponseInterface $lastConnectResponse = null; // Last connect response for header testing public function __construct(int $timeout = 2) { @@ -32,6 +33,7 @@ public function connect(string $sseBaseUrl): PromiseInterface { return $this->browser->requestStreaming('GET', $sseBaseUrl) ->then(function (ResponseInterface $response) { + $this->lastConnectResponse = $response; // Store response for header testing if ($response->getStatusCode() !== 200) { $body = (string) $response->getBody(); throw new \RuntimeException("SSE connection failed with status {$response->getStatusCode()}: {$body}");