Skip to content

Move BackgroundTask execution outside of request/response cycle #2176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from starlette.datastructures import State, URLPath
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.background import BackgroundTaskMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
Expand Down Expand Up @@ -96,6 +97,7 @@ def build_middleware_stack(self) -> ASGIApp:

middleware = (
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+ [Middleware(BackgroundTaskMiddleware)]
Comment on lines 99 to +100
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+ [Middleware(BackgroundTaskMiddleware)]
[
Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug),
Middleware(BackgroundTaskMiddleware),
]

+ self.user_middleware
+ [
Middleware(
Expand Down
37 changes: 37 additions & 0 deletions starlette/middleware/background.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import List, cast

from starlette.background import BackgroundTask
from starlette.types import ASGIApp, Receive, Scope, Send

# consider this a private implementation detail subject to change
# do not rely on this key
_SCOPE_KEY = "starlette._background"


_BackgroundTaskList = List[BackgroundTask]


def is_background_task_middleware_installed(scope: Scope) -> bool:
return _SCOPE_KEY in scope


def add_tasks(scope: Scope, task: BackgroundTask, /) -> None:
if _SCOPE_KEY not in scope: # pragma: no cover
raise RuntimeError(
"`add_tasks` can only be used if `BackgroundTaskMIddleware is installed"
)
Comment on lines +19 to +22
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you trigger this, tho? If we always call is_background_task_middleware_installed before the add_tasks, then this will never be called. I guess the point you are trying to make here is to make this as public API, but enforce it for others.

We have some scenarios:

  1. You are running Starlette(), which you always have the BackgroundTasksMiddleware - in this case we wouldn't need to check the scopes, because we know the scope[_SCOPE_KEY] is there.
  2. You are running only parts of Starlette e.g. Response - in this case, you don't have the BackgroundTasksMiddleware, which means this would be a breaking change - oh... Now I see what you did. You kept the previous behavior where it runs the background task if standalone Response with background tasks is used.

Ok... As you see, I was figuring stuff as I was writing this to myself.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like that now the logic is different on those scenarios. I'm having a hard time figuring out how to avoid it tho...

We are in a similar situation with the exception of middleware.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually... If the _SCOPE_KEY is not present, we can error as above, which will make it clear to users that they need to wrap their responses with BackgroundTasksMiddleware.

What I mean is to not check if the BackgroundTasksMiddleware is installed, and error if user tries to use it outside the Starlette application.


Maybe that's your plan but with a deprecation warning?

cast(_BackgroundTaskList, scope[_SCOPE_KEY]).append(task)


class BackgroundTaskMiddleware:
def __init__(self, app: ASGIApp) -> None:
self._app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
tasks: _BackgroundTaskList
scope[_SCOPE_KEY] = tasks = []
try:
await self._app(scope, receive, send)
finally:
for task in tasks:
await task()
Comment on lines +35 to +37
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Background tasks should only run on 2xx.

I guess you'll need a wrap_send to get the status code.

19 changes: 19 additions & 0 deletions starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, MutableHeaders
from starlette.middleware import background
from starlette.types import Receive, Scope, Send


Expand Down Expand Up @@ -148,6 +149,12 @@ def delete_cookie(
)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if (
self.background is not None
and background.is_background_task_middleware_installed(scope)
):
background.add_tasks(scope, self.background)
self.background = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the scenario you are trying to prevent when setting background to None? What if a user has a global response object?

response = Response(background=BackgroundTasks(...))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see now. You are doing because you didn't remove the logic below where we run the background task.

I think we can remove that logic, and you can avoid this line.

(Is my understanding correct?)

prefix = "websocket." if scope["type"] == "websocket" else ""
await send(
{
Expand Down Expand Up @@ -255,6 +262,12 @@ async def stream_response(self, send: Send) -> None:
await send({"type": "http.response.body", "body": b"", "more_body": False})

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if (
self.background is not None
and background.is_background_task_middleware_installed(scope)
):
background.add_tasks(scope, self.background)
self.background = None
async with anyio.create_task_group() as task_group:

async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
Expand Down Expand Up @@ -322,6 +335,12 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None:
self.headers.setdefault("etag", etag)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if (
self.background is not None
and background.is_background_task_middleware_installed(scope)
):
background.add_tasks(scope, self.background)
self.background = None
if self.stat_result is None:
try:
stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
Expand Down
104 changes: 99 additions & 5 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,26 @@
from typing import (
Any,
AsyncGenerator,
Callable,
Generator,
Literal,
)

import anyio
import pytest
from anyio.abc import TaskStatus

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.background import BackgroundTask, BackgroundTasks
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.background import BackgroundTaskMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket

TestClientFactory = Callable[[ASGIApp], TestClient]
from tests.conftest import TestClientFactory


class CustomMiddleware(BaseHTTPMiddleware):
Expand Down Expand Up @@ -372,8 +372,8 @@ async def send(message: Message) -> None:
{"body": b"Hello", "more_body": True, "type": "http.response.body"},
{"body": b"", "more_body": False, "type": "http.response.body"},
"Background task started",
"Background task started",
"Background task finished",
"Background task started",
"Background task finished",
]

Expand Down Expand Up @@ -1035,3 +1035,97 @@ async def endpoint(request: Request) -> Response:
resp.raise_for_status()

assert bodies == [b"Hello, World!-foo"]


@pytest.mark.anyio
async def test_background_tasks_client_disconnect() -> None:
# test for https://github.com/encode/starlette/issues/1438
container: list[str] = []

disconnected = anyio.Event()

async def slow_background() -> None:
# small delay to give BaseHTTPMiddleware a chance to cancel us
# this is required to make the test fail prior to fixing the issue
# so do not be surprised if you remove it and the test still passes
await anyio.sleep(0.1)
container.append("called")

app: ASGIApp
app = PlainTextResponse("hi!", background=BackgroundTask(slow_background))

async def dispatch(
request: Request, call_next: RequestResponseEndpoint
) -> Response:
return await call_next(request)

app = BaseHTTPMiddleware(app, dispatch=dispatch)

app = BackgroundTaskMiddleware(app)

async def recv_gen() -> AsyncGenerator[Message, None]:
yield {"type": "http.request"}
await disconnected.wait()
while True:
yield {"type": "http.disconnect"}

async def send_gen() -> AsyncGenerator[None, Message]:
while True:
msg = yield
if msg["type"] == "http.response.body" and not msg.get("more_body", False):
disconnected.set()

scope = {"type": "http", "method": "GET", "path": "/"}

async with AsyncExitStack() as stack:
recv = recv_gen()
stack.push_async_callback(recv.aclose)
send = send_gen()
stack.push_async_callback(send.aclose)
await send.__anext__()
await app(scope, recv.__aiter__().__anext__, send.asend)

assert container == ["called"]


@pytest.mark.anyio
async def test_background_tasks_failure(
test_client_factory: TestClientFactory,
anyio_backend_name: Literal["asyncio", "trio"],
) -> None:
if anyio_backend_name == "trio":
pytest.skip("this test hangs with trio")

# test for https://github.com/encode/starlette/discussions/2640
container: list[str] = []

async def task1() -> None:
container.append("task1 called")
raise ValueError("task1 failed")

async def task2() -> None:
container.append("task2 called") # pragma: no cover

async def endpoint(request: Request) -> Response:
background = BackgroundTasks()
background.add_task(task1)
background.add_task(task2)
return PlainTextResponse("hi!", background=background)

async def dispatch(
request: Request, call_next: RequestResponseEndpoint
) -> Response:
return await call_next(request)

app = Starlette(
routes=[Route("/", endpoint)],
middleware=[Middleware(BaseHTTPMiddleware, dispatch=dispatch)],
)

client = test_client_factory(app, raise_server_exceptions=False)

response = client.get("/")
assert response.status_code == 200
assert response.text == "hi!"

assert container == ["task1 called"]
Loading