diff --git a/stac_fastapi/pgstac/models/links.py b/stac_fastapi/pgstac/models/links.py index ae7e3cf..2b6d9dd 100644 --- a/stac_fastapi/pgstac/models/links.py +++ b/stac_fastapi/pgstac/models/links.py @@ -51,35 +51,30 @@ def base_url(self): @property def url(self): """Get the current request url.""" - base_url = self.request.base_url - path = self.request.url.path - - # root path can be set in the request scope in two different ways: - # - by uvicorn when running with --root-path - # - by FastAPI when running with FastAPI(root_path="...") - # - # When root path is set by uvicorn, request.url.path will have the root path prefix. - # eg. if root path is "/api" and the path is "/collections", - # the request.url.path will be "/api/collections" - # - # We need to remove the root path prefix from the path before - # joining the base_url and path to get the full url to avoid - # having root_path twice in the url - if ( - root_path := self.request.scope.get("root_path") - ) and not self.request.app.root_path: - # self.request.app.root_path is set by FastAPI when running with FastAPI(root_path="...") - # If self.request.app.root_path is not set but self.request.scope.get("root_path") is set, - # then the root path is set by uvicorn - # So we need to remove the root path prefix from the path before - # joining the base_url and path to get the full url - if path.startswith(root_path): - path = path[len(root_path) :] - - url = urljoin(str(base_url), path.lstrip("/")) + # self.base_url calls stac_fastapi.types.requests.get_base_url which accounts for ROOT_PATH env var. + # Ensure it has a trailing slash for urljoin. + final_base_url = str(self.base_url) + if not final_base_url.endswith("/"): + final_base_url += "/" + + # Path of the current request from the server root, e.g., /env_root/items/1 + current_item_path = self.request.url.path + + # Get the path component of final_base_url, e.g., /env_root/ (guaranteed by previous step) + final_base_url_path_part = urlparse(final_base_url).path + + # Make current_item_path relative to final_base_url_path_part + # e.g., if current_item_path is /env_root/items/1 and final_base_url_path_part is /env_root/, + # then path_relative_to_base becomes items/1 + path_relative_to_base = current_item_path + if current_item_path.startswith(final_base_url_path_part): + path_relative_to_base = current_item_path[len(final_base_url_path_part):] + + # urljoin will combine them correctly: urljoin("http://server/env_root/", "items/1") + url = urljoin(final_base_url, path_relative_to_base) + if qs := self.request.url.query: url += f"?{qs}" - return url def resolve(self, url): diff --git a/tests/api/test_links.py b/tests/api/test_links.py index e8e57a9..624fbf1 100644 --- a/tests/api/test_links.py +++ b/tests/api/test_links.py @@ -1,8 +1,15 @@ +import os +from unittest import mock +from urllib.parse import urlparse + import pytest from fastapi import APIRouter, FastAPI from starlette.requests import Request from starlette.testclient import TestClient +# Assuming app is defined in stac_fastapi.pgstac.app +# If not, this import will need adjustment. +from stac_fastapi.pgstac.app import app from stac_fastapi.pgstac.models import links as app_links @@ -71,6 +78,64 @@ async def collections(request: Request): assert link["href"].startswith(url_prefix) assert {"next", "previous", "root", "self"} == {link["rel"] for link in links} + +# The load_test_data fixture is assumed to be defined in conftest.py +# and to load enough items for pagination to occur with limit=1. +@mock.patch.dict(os.environ, {"ROOT_PATH": "/custom/api/root"}) +def test_pagination_link_with_root_path(load_test_data): + """Test that pagination links are correct when ROOT_PATH is set.""" + # get_base_url directly uses os.getenv("ROOT_PATH"), so patching + # os.environ should be effective for new requests. + # The TestClient is initialized after the patch. + + # Use the global `app` imported from stac_fastapi.pgstac.app + # The TestClient for STAC FastAPI typically uses http://testserver as base + with TestClient(app, base_url="http://testserver") as client: + # Perform a search that should result in a 'next' link + # Assumes load_test_data has loaded more than 1 item + response = client.get("/search?limit=1") + response_json = response.json() + + assert response.status_code == 200, f"Response content: {response.text}" + next_link = None + for link in response_json.get("links", []): + if link.get("rel") == "next": + next_link = link + break + + assert next_link is not None, "Next link not found in response" + + href = next_link["href"] + + # Expected: http://testserver/custom/api/root/search?limit=1&token=next:... + # Not: http://testserver/custom/api/root/custom/api/root/search?... + + parsed_href = urlparse(href) + path = parsed_href.path + + # Check that the path starts correctly with the root path and the endpoint + expected_start_path = "/custom/api/root/search" + assert path.startswith(expected_start_path), f"Path {path} does not start with {expected_start_path}" + + # Check that the root path segment is not duplicated + # e.g. path should not be /custom/api/root/custom/api/root/search + duplicated_root_path = "/custom/api/root/custom/api/root/" + assert not path.startswith(duplicated_root_path), f"Path {path} shows duplicated root path starting with {duplicated_root_path}" + + # A more precise check for occurrences of the root path segments + # Path: /custom/api/root/search + # Root Path: /custom/api/root + # Effective path segments to check for duplication: custom/api/root + path_segments = path.strip('/').split('/') + root_path_segments_to_check = "custom/api/root".split('/') + + occurrences = 0 + for i in range(len(path_segments) - len(root_path_segments_to_check) + 1): + if path_segments[i:i+len(root_path_segments_to_check)] == root_path_segments_to_check: + occurrences += 1 + + assert occurrences == 1, f"Expected ROOT_PATH segments to appear once, but found {occurrences} times in path {path}. Segments: {path_segments}" + response = client.get(f"{prefix}/search", params={"limit": 1}) assert response.status_code == 200 assert response.json()["url"] == url_prefix + "/search?limit=1"