Skip to content

Commit

Permalink
fix(file response): Support varying mtime semantics across different fsspec implementations (litestar-org#3902)
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut authored Dec 15, 2024
1 parent b30c7fc commit 87f8aa0
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 6 deletions.
48 changes: 43 additions & 5 deletions litestar/response/file.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import itertools
from datetime import datetime
from email.utils import formatdate
from inspect import iscoroutine
from mimetypes import encodings_map, guess_type
from typing import TYPE_CHECKING, Any, AsyncGenerator, Coroutine, Iterable, Literal, cast
from typing import TYPE_CHECKING, Any, AsyncGenerator, Coroutine, Final, Iterable, Literal, cast
from urllib.parse import quote
from zlib import adler32

Expand Down Expand Up @@ -69,7 +70,7 @@ async def async_file_iterator(
yield chunk


def create_etag_for_file(path: PathType, modified_time: float, file_size: int) -> str:
def create_etag_for_file(path: PathType, modified_time: float | None, file_size: int) -> str:
"""Create an etag.
Notes:
Expand All @@ -79,7 +80,42 @@ def create_etag_for_file(path: PathType, modified_time: float, file_size: int) -
An etag.
"""
check = adler32(str(path).encode("utf-8")) & 0xFFFFFFFF
return f'"{modified_time}-{file_size}-{check}"'
parts = [str(file_size), str(check)]
if modified_time:
parts.insert(0, str(modified_time))
return f'"{"-".join(parts)}"'


# Keys under which different fsspec implementations report a modification
# time, checked in this order; the first key holding a value wins.
_MTIME_KEYS: Final = (
    "mtime",
    "ctime",
    "Last-Modified",
    "updated_at",
    "modification_time",
    "last_changed",
    "change_time",
    "last_modified",
    "last_updated",
    "timestamp",
)


def get_fsspec_mtime_equivalent(info: dict[str, Any]) -> float | None:
    """Return the 'mtime' or equivalent for different fsspec implementations, since they
    are not standardized.

    See https://github.com/fsspec/filesystem_spec/issues/526.

    Args:
        info: File info mapping to search for one of the known mtime-like keys.

    Returns:
        The modification time as a POSIX timestamp, or ``None`` if none of the
        known keys holds a value.

    Raises:
        ValueError: If the value found is not a number, ``datetime`` or ISO 8601 string.
    """
    # inspired by https://github.com/mdshw5/pyfaidx/blob/cac82f24e9c4e334cf87a92e477b92d4615d260f/pyfaidx/__init__.py#L1318-L1345
    # Skip keys that are present but empty so that they do not mask a later
    # key which actually carries a usable value.
    mtime: Any | None = next(
        (value for value in map(info.get, _MTIME_KEYS) if value is not None),
        None,
    )
    if mtime is None:
        return None
    # Accept integer epoch values as well as floats. 'bool' is a subclass of
    # 'int' but is never a meaningful timestamp, so it stays unsupported.
    if isinstance(mtime, (int, float)) and not isinstance(mtime, bool):
        return float(mtime)
    if isinstance(mtime, datetime):
        return mtime.timestamp()
    if isinstance(mtime, str):
        # 'fromisoformat' only understands a trailing 'Z' on newer Python
        # versions, so normalise it to an explicit UTC offset first.
        return datetime.fromisoformat(mtime.replace("Z", "+00:00")).timestamp()

    raise ValueError(f"Unsupported mtime-type value type {type(mtime)!r}")


class ASGIFileResponse(ASGIStreamingResponse):
Expand Down Expand Up @@ -217,14 +253,16 @@ async def start_response(self, send: Send) -> None:
self.content_length = fs_info["size"]

self.headers.setdefault("content-length", str(self.content_length))
self.headers.setdefault("last-modified", formatdate(fs_info["mtime"], usegmt=True))
mtime = get_fsspec_mtime_equivalent(fs_info) # type: ignore[arg-type]
if mtime is not None:
self.headers.setdefault("last-modified", formatdate(mtime, usegmt=True))

if self.etag:
self.headers.setdefault("etag", self.etag.to_header())
else:
self.headers.setdefault(
"etag",
create_etag_for_file(path=self.file_path, modified_time=fs_info["mtime"], file_size=fs_info["size"]),
create_etag_for_file(path=self.file_path, modified_time=mtime, file_size=fs_info["size"]),
)

await super().start_response(send=send)
Expand Down
115 changes: 114 additions & 1 deletion tests/unit/test_response/test_file_response.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from datetime import datetime, timezone
from email.utils import formatdate
from os import stat, urandom
from pathlib import Path
Expand All @@ -13,7 +14,7 @@
from litestar.exceptions import ImproperlyConfiguredException
from litestar.file_system import BaseLocalFileSystem, FileSystemAdapter
from litestar.response.file import ASGIFileResponse, File, async_file_iterator
from litestar.status_codes import HTTP_200_OK
from litestar.status_codes import HTTP_200_OK, HTTP_500_INTERNAL_SERVER_ERROR
from litestar.testing import create_test_client
from litestar.types import FileSystemProtocol

Expand Down Expand Up @@ -95,6 +96,118 @@ def handler() -> File:
assert response.headers["last-modified"].lower() == formatdate(path.stat().st_mtime, usegmt=True).lower()


@pytest.mark.parametrize(
    "mtime,expected_last_modified",
    [
        pytest.param(
            datetime(2000, 1, 2, 3, 4, 5, tzinfo=timezone.utc).timestamp(),
            "Sun, 02 Jan 2000 03:04:05 GMT",
            id="timestamp",
        ),
        pytest.param(
            datetime(2000, 1, 2, 3, 4, 5, tzinfo=timezone.utc),
            "Sun, 02 Jan 2000 03:04:05 GMT",
            id="datetime",
        ),
        pytest.param(
            datetime(2000, 1, 2, 3, 4, 5, tzinfo=timezone.utc).isoformat(),
            "Sun, 02 Jan 2000 03:04:05 GMT",
            id="isoformat",
        ),
    ],
)
@pytest.mark.parametrize(
    "mtime_key",
    [
        "mtime",
        "ctime",
        "Last-Modified",
        "updated_at",
        "modification_time",
        "last_changed",
        "change_time",
        "last_modified",
        "last_updated",
        "timestamp",
    ],
)
def test_file_response_last_modified_file_info_formats(
    tmpdir: Path, mtime: Any, mtime_key: str, expected_last_modified: str
) -> None:
    """Check 'last-modified' for every mtime key and each supported value format."""
    target = Path(tmpdir / "file.txt")
    target.write_bytes(b"")

    # Start from the fixed entries, then add the mtime under the key being tested.
    info = {"name": "file.txt", "size": 0, "type": "file"}
    info[mtime_key] = mtime

    @get("/")
    def handler() -> File:
        return File(
            path=target,
            filename="image.png",
            file_info=info,  # type: ignore[arg-type]
        )

    with create_test_client(handler) as client:
        response = client.get("/")

    assert response.status_code == HTTP_200_OK
    assert response.headers["last-modified"].lower() == expected_last_modified.lower()


def test_file_response_last_modified_unsupported_mtime_type(tmpdir: Path) -> None:
    """An mtime value of an unsupported type results in a 500 without 'last-modified'."""
    target = Path(tmpdir / "file.txt")
    target.write_bytes(b"")
    # A bare 'object()' is none of the supported mtime representations.
    info = {"name": "file.txt", "size": 0, "type": "file", "last_updated": object()}

    @get("/")
    def handler() -> File:
        return File(
            path=target,
            filename="image.png",
            file_info=info,  # type: ignore[arg-type]
        )

    with create_test_client(handler) as client:
        response = client.get("/")

    assert response.status_code == HTTP_500_INTERNAL_SERVER_ERROR
    assert "last-modified" not in response.headers


def test_file_response_last_modified_mtime_not_given(tmpdir: Path) -> None:
    """Without any mtime-like key in the file info, no 'last-modified' header is sent."""
    target = Path(tmpdir / "file.txt")
    target.write_bytes(b"")
    # Deliberately contains no mtime-like key.
    info = {"name": "file.txt", "size": 0, "type": "file"}

    @get("/")
    def handler() -> File:
        return File(
            path=target,
            filename="image.png",
            file_info=info,  # type: ignore[arg-type]
        )

    with create_test_client(handler) as client:
        response = client.get("/")

    assert response.status_code == HTTP_200_OK
    assert "last-modified" not in response.headers


def test_file_response_etag_without_mtime(tmpdir: Path) -> None:
    """Without an mtime, the generated etag consists of two dash-separated parts."""
    target = Path(tmpdir / "file.txt")
    target.write_bytes(b"")
    # Deliberately contains no mtime-like key.
    info = {"name": "file.txt", "size": 0, "type": "file"}

    @get("/")
    def handler() -> File:
        return File(
            path=target,
            filename="image.png",
            file_info=info,  # type: ignore[arg-type]
        )

    with create_test_client(handler) as client:
        response = client.get("/")

    assert response.status_code == HTTP_200_OK
    # we expect etag to only have 2 parts here because no mtime was given
    etag_parts = response.headers.get("etag", "").split("-")
    assert len(etag_parts) == 2


async def test_file_response_with_directory_raises_error(tmpdir: Path) -> None:
with pytest.raises(ImproperlyConfiguredException):
asgi_response = ASGIFileResponse(file_path=tmpdir, filename="example.png")
Expand Down

0 comments on commit 87f8aa0

Please sign in to comment.