Skip to content

Fix: Prevent duplicate root_path in pagination links #250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 22 additions & 27 deletions stac_fastapi/pgstac/models/links.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,35 +51,30 @@ def base_url(self):
@property
def url(self):
    """Get the current request url.

    The URL is rebuilt from ``self.base_url`` plus the request path taken
    relative to that base, followed by the original query string (if any).

    ``self.base_url`` is expected to already include any configured root
    path (the original note says it goes through
    ``stac_fastapi.types.requests.get_base_url``, which honours the
    ``ROOT_PATH`` env var -- confirm against that helper). The request
    path may or may not carry the same prefix, depending on how the root
    path was configured, so the prefix is removed only when present. This
    keeps the root path from showing up twice in generated links.

    Returns:
        The absolute URL of the current request as a string.
    """
    # urljoin only appends to the base when the base ends with "/";
    # otherwise it would replace the last path segment.
    base = str(self.base_url)
    if not base.endswith("/"):
        base += "/"

    # Path component of the base, e.g. "/env_root/" (always ends with "/").
    base_path = urlparse(base).path

    # Path of the current request from the server root, e.g. "/env_root/items/1".
    request_path = self.request.url.path

    if request_path.startswith(base_path):
        # "/env_root/items/1" relative to "/env_root/" -> "items/1"
        relative_path = request_path[len(base_path):]
    elif request_path == base_path.rstrip("/"):
        # Request for the root itself without a trailing slash ("/env_root").
        relative_path = ""
    else:
        # The request path does not carry the root prefix (e.g. it was
        # stripped upstream); append it to the base unchanged.
        relative_path = request_path

    # A leading "/" would make urljoin treat the path as absolute and drop
    # the base's path (and with it the root path), so always join a
    # relative reference.
    url = urljoin(base, relative_path.lstrip("/"))

    if qs := self.request.url.query:
        url += f"?{qs}"

    return url

def resolve(self, url):
Expand Down
65 changes: 65 additions & 0 deletions tests/api/test_links.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import os
from unittest import mock
from urllib.parse import urlparse

import pytest
from fastapi import APIRouter, FastAPI
from starlette.requests import Request
from starlette.testclient import TestClient

# Assuming app is defined in stac_fastapi.pgstac.app
# If not, this import will need adjustment.
from stac_fastapi.pgstac.app import app
from stac_fastapi.pgstac.models import links as app_links


Expand Down Expand Up @@ -71,6 +78,64 @@ async def collections(request: Request):
assert link["href"].startswith(url_prefix)
assert {"next", "previous", "root", "self"} == {link["rel"] for link in links}


# NOTE: relies on a `load_test_data` fixture (assumed to live in conftest.py)
# that loads enough items for a search with limit=1 to paginate.
@mock.patch.dict(os.environ, {"ROOT_PATH": "/custom/api/root"})
def test_pagination_link_with_root_path(load_test_data):
    """Pagination links must contain the ROOT_PATH prefix exactly once."""
    # The environment is patched before the client is created, so requests
    # issued below are expected to see ROOT_PATH when the base url is built.
    with TestClient(app, base_url="http://testserver") as client:
        # A one-item page should yield a "next" link.
        response = client.get("/search?limit=1")
        response_json = response.json()

        assert response.status_code == 200, f"Response content: {response.text}"

        next_link = next(
            (
                candidate
                for candidate in response_json.get("links", [])
                if candidate.get("rel") == "next"
            ),
            None,
        )
        assert next_link is not None, "Next link not found in response"

        # Wanted:   http://testserver/custom/api/root/search?limit=1&token=next:...
        # Unwanted: http://testserver/custom/api/root/custom/api/root/search?...
        path = urlparse(next_link["href"]).path

        # The path begins with the root path followed by the endpoint.
        expected_start_path = "/custom/api/root/search"
        assert path.startswith(expected_start_path), f"Path {path} does not start with {expected_start_path}"

        # The root path prefix is not immediately repeated.
        duplicated_root_path = "/custom/api/root/custom/api/root/"
        assert not path.startswith(duplicated_root_path), f"Path {path} shows duplicated root path starting with {duplicated_root_path}"

        # Stricter check: slide a window over the path segments and count
        # how many times the root path segments appear as a contiguous run.
        path_segments = path.strip('/').split('/')
        root_segments = "custom/api/root".split('/')
        width = len(root_segments)

        occurrences = sum(
            1
            for start in range(len(path_segments) - width + 1)
            if path_segments[start:start + width] == root_segments
        )

        assert occurrences == 1, f"Expected ROOT_PATH segments to appear once, but found {occurrences} times in path {path}. Segments: {path_segments}"

response = client.get(f"{prefix}/search", params={"limit": 1})
assert response.status_code == 200
assert response.json()["url"] == url_prefix + "/search?limit=1"
Expand Down