diff --git a/src/CsrfMiddleware.php b/src/CsrfMiddleware.php index cffab6c..e6671f3 100644 --- a/src/CsrfMiddleware.php +++ b/src/CsrfMiddleware.php @@ -11,6 +11,7 @@ use Psr\Http\Server\RequestHandlerInterface; use Yiisoft\Http\Method; use Yiisoft\Http\Status; +use Yiisoft\Csrf\CsrfTrait; use function in_array; use function is_string; @@ -21,12 +22,8 @@ * @link https://www.php-fig.org/psr/psr-15/ */ final class CsrfMiddleware implements MiddlewareInterface -{ - public const PARAMETER_NAME = '_csrf'; - public const HEADER_NAME = 'X-CSRF-Token'; - - private string $parameterName = self::PARAMETER_NAME; - private string $headerName = self::HEADER_NAME; +{ + use CsrfTrait; private ResponseFactoryInterface $responseFactory; private CsrfTokenInterface $token; @@ -59,29 +56,7 @@ public function process(ServerRequestInterface $request, RequestHandlerInterface return $response; } - public function withParameterName(string $name): self - { - $new = clone $this; - $new->parameterName = $name; - return $new; - } - - public function withHeaderName(string $name): self - { - $new = clone $this; - $new->headerName = $name; - return $new; - } - - public function getParameterName(): string - { - return $this->parameterName; - } - - public function getHeaderName(): string - { - return $this->headerName; - } + private function validateCsrfToken(ServerRequestInterface $request): bool { @@ -98,7 +73,7 @@ private function getTokenFromRequest(ServerRequestInterface $request): ?string { $parsedBody = $request->getParsedBody(); - $token = $parsedBody[$this->parameterName] ?? null; + $token = $parsedBody[$this->formParameterName] ?? null; if (empty($token)) { $headers = $request->getHeader($this->headerName); $token = reset($headers); diff --git a/src/CsrfTrait b/src/CsrfTrait new file mode 100644 index 0000000..5b17919 --- /dev/null +++ b/src/CsrfTrait @@ -0,0 +1,46 @@ +formParameterName = $name; + return $new; + } + + public function withHeaderName(string $name): self + { + $new = clone $this; + $new->headerName = $name; + return $new; + } + + public function getFormParameterName(): string + { + return $this->formParameterName; + } + + public function getHeaderName(): string + { + return $this->headerName; + } + + +}