diff --git a/src/auth_server/routers/interaction.py b/src/auth_server/routers/interaction.py index 49cf75d..59a7b7b 100644 --- a/src/auth_server/routers/interaction.py +++ b/src/auth_server/routers/interaction.py @@ -7,16 +7,16 @@ from fastapi import APIRouter, BackgroundTasks, Form, HTTPException from loguru import logger from starlette.responses import HTMLResponse, RedirectResponse, Response +from starlette.templating import Jinja2Templates from auth_server.config import load_config from auth_server.context import ContextRequest, ContextRequestRoute from auth_server.db.transaction_state import FlowState, TransactionState, get_transaction_state_db from auth_server.models.gnap import FinishInteractionMethod -from auth_server.templating import TestableJinja2Templates from auth_server.utils import get_interaction_hash, push_interaction_finish interaction_router = APIRouter(route_class=ContextRequestRoute, prefix="/interaction") -templates = TestableJinja2Templates(directory=str(Path(__file__).with_name("templates"))) +templates = Jinja2Templates(directory=str(Path(__file__).with_name("templates"))) @interaction_router.get("/redirect/{transaction_id}", response_class=HTMLResponse) diff --git a/src/auth_server/routers/saml2_sp.py b/src/auth_server/routers/saml2_sp.py index 8d53b25..5bff4b0 100644 --- a/src/auth_server/routers/saml2_sp.py +++ b/src/auth_server/routers/saml2_sp.py @@ -10,6 +10,7 @@ from saml2.metadata import entity_descriptor from saml2.response import StatusError from starlette.responses import HTMLResponse, RedirectResponse +from starlette.templating import Jinja2Templates from auth_server.config import load_config from auth_server.context import ContextRequest, ContextRequestRoute @@ -23,10 +24,9 @@ get_saml2_sp, process_assertion, ) -from auth_server.templating import TestableJinja2Templates saml2_router = APIRouter(route_class=ContextRequestRoute, prefix="/saml2") -templates = TestableJinja2Templates(directory=str(Path(__file__).with_name("templates"))) +templates = Jinja2Templates(directory=str(Path(__file__).with_name("templates"))) @saml2_router.get("/sp/authn/{transaction_id}", response_class=HTMLResponse) diff --git a/src/auth_server/templating.py b/src/auth_server/templating.py deleted file mode 100644 index 96922d5..0000000 --- a/src/auth_server/templating.py +++ /dev/null @@ -1,41 +0,0 @@ -# -*- coding: utf-8 -*- -from typing import Mapping, Optional - -from starlette.responses import Response -from starlette.templating import Jinja2Templates as _Jinja2Templates -from starlette.templating import _TemplateResponse - -__author__ = "lundberg" - - -# Workaround for bug in Starlette. -# https://github.com/encode/starlette/issues/472#issuecomment-612398116 - - -class TestableJinja2Templates(_Jinja2Templates): - def TemplateResponse( - self, - name: str, - context: dict, - status_code: int = 200, - headers: Optional[Mapping[str, str]] = None, - media_type: Optional[str] = None, - background=None, - ) -> _TemplateResponse: - if "request" not in context: - raise ValueError('context must include a "request" key') - template = self.get_template(name) - return CustomTemplateResponse( - template, - context, - status_code=status_code, - headers=headers, - media_type=media_type, - background=background, - ) - - -class CustomTemplateResponse(_TemplateResponse): - async def __call__(self, scope, receive, send) -> None: - # context sending removed - await Response.__call__(self, scope, receive, send)