Skip to content

Commit

Permalink
Change to app factory pattern (#224)
Browse files Browse the repository at this point in the history
* Change to app factory pattern. The factory pattern is supported by uvicorn, gunicorn, and FastAPI. It allows parameterised creation of the App object among other benefits.

* update Dockerfile and prez tests to use factory pattern
  • Loading branch information
ashleysommer authored May 1, 2024
1 parent 8b979db commit 45e8426
Show file tree
Hide file tree
Showing 17 changed files with 171 additions and 94 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ WORKDIR /app
# prez module is already built as a package and installed in $VIRTUAL_ENV as a library
COPY main.py pyproject.toml ./

ENTRYPOINT uvicorn prez.app:app --host=${HOST:-0.0.0.0} --port=${PORT:-8000} --proxy-headers
ENTRYPOINT uvicorn prez.app:assemble_app --factory --host=${HOST:-0.0.0.0} --port=${PORT:-8000} --proxy-headers
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@

port = int(environ.get("PREZ_DEV_SERVER_PORT", 8000))

uvicorn.run("prez.app:app", port=port, reload=True)
uvicorn.run("prez.app:assemble_app", factory=True, port=port, reload=True)
193 changes: 118 additions & 75 deletions prez/app.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import logging
import os
from functools import partial
from textwrap import dedent

from typing import Optional, Dict, Union, Any
import uvicorn
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from rdflib import Graph
from starlette.middleware.cors import CORSMiddleware

from prez.config import settings
from prez.config import settings, Settings
from prez.dependencies import (
get_async_http_client,
get_pyoxi_store,
Expand Down Expand Up @@ -53,34 +53,28 @@
from prez.services.search_methods import get_all_search_methods
from prez.sparql.methods import RemoteSparqlRepo, PyoxigraphRepo, OxrdflibRepo

app = FastAPI(
exception_handlers={
400: catch_400,
404: catch_404,
500: catch_500,
ClassNotFoundException: catch_class_not_found_exception,
URINotFoundException: catch_uri_not_found_exception,
NoProfilesException: catch_no_profiles_exception,
}
)


app.include_router(cql_router)
app.include_router(management_router)
app.include_router(object_router)
app.include_router(sparql_router)
app.include_router(search_router)
app.include_router(profiles_router)
if "CatPrez" in settings.prez_flavours:
app.include_router(catprez_router)
if "VocPrez" in settings.prez_flavours:
app.include_router(vocprez_router)
if "SpacePrez" in settings.prez_flavours:
app.include_router(spaceprez_router)
app.include_router(identifier_router)
def prez_open_api_metadata(
title: str,
description: str,
version: str,
contact: Optional[Dict[str, Union[str, Any]]],
_app: FastAPI,
):
return get_openapi(
title=title,
description=description,
version=version,
contact=contact,
license_info=_app.license_info,
openapi_version=_app.openapi_version,
terms_of_service=_app.terms_of_service,
tags=_app.openapi_tags,
servers=_app.servers,
routes=_app.routes,
)


@app.middleware("http")
async def add_cors_headers(request, call_next):
response = await call_next(request)
response.headers["Access-Control-Allow-Origin"] = "*"
Expand All @@ -90,69 +84,45 @@ async def add_cors_headers(request, call_next):
return response


app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["*"],
)


def prez_open_api_metadata():
return get_openapi(
title=settings.prez_title,
version=settings.prez_version,
description=settings.prez_desc,
routes=app.routes,
)


app.openapi = prez_open_api_metadata


@app.on_event("startup")
async def app_startup():
async def app_startup(_settings: Settings, _app: FastAPI):
"""
This function runs at startup and will continually poll the separate backends until their SPARQL endpoints
are available. Initial caching can be triggered within the try block. NB this function does not check that data is
appropriately configured at the SPARQL endpoint(s), only that the SPARQL endpoint(s) are reachable.
"""
setup_logger(settings)
setup_logger(_settings)
log = logging.getLogger("prez")
log.info("Starting up")

if settings.sparql_repo_type == "pyoxigraph":
app.state.pyoxi_store = get_pyoxi_store()
app.state.repo = PyoxigraphRepo(app.state.pyoxi_store)
await load_local_data_to_oxigraph(app.state.pyoxi_store)
elif settings.sparql_repo_type == "oxrdflib":
app.state.oxrdflib_store = get_oxrdflib_store()
app.state.repo = OxrdflibRepo(app.state.oxrdflib_store)
elif settings.sparql_repo_type == "remote":
app.state.http_async_client = await get_async_http_client()
app.state.repo = RemoteSparqlRepo(app.state.http_async_client)
if _settings.sparql_repo_type == "pyoxigraph":
_app.state.pyoxi_store = get_pyoxi_store()
_app.state.repo = _repo = PyoxigraphRepo(_app.state.pyoxi_store)
await load_local_data_to_oxigraph(_app.state.pyoxi_store)
elif _settings.sparql_repo_type == "oxrdflib":
_app.state.oxrdflib_store = get_oxrdflib_store()
_app.state.repo = _repo = OxrdflibRepo(_app.state.oxrdflib_store)
elif _settings.sparql_repo_type == "remote":
_app.state.http_async_client = await get_async_http_client()
_app.state.repo = _repo = RemoteSparqlRepo(_app.state.http_async_client)
await healthcheck_sparql_endpoints()
else:
raise ValueError(
"SPARQL_REPO_TYPE must be one of 'pyoxigraph', 'oxrdflib' or 'remote'"
)

await add_prefixes_to_prefix_graph(app.state.repo)
await get_all_search_methods(app.state.repo)
await create_profiles_graph(app.state.repo)
await create_endpoints_graph(app.state.repo)
await count_objects(app.state.repo)
await add_prefixes_to_prefix_graph(_repo)
await get_all_search_methods(_repo)
await create_profiles_graph(_repo)
await create_endpoints_graph(_repo)
await count_objects(_repo)
await populate_api_info()
await add_common_context_ontologies_to_tbox_cache()

app.state.pyoxi_system_store = get_system_store()
await load_system_data_to_oxigraph(app.state.pyoxi_system_store)
_app.state.pyoxi_system_store = get_system_store()
await load_system_data_to_oxigraph(_app.state.pyoxi_system_store)


@app.on_event("shutdown")
async def app_shutdown():
async def app_shutdown(_settings: Settings, _app: FastAPI):
"""
persists caches
close async SPARQL clients
Expand All @@ -161,8 +131,81 @@ async def app_shutdown():
log.info("Shutting down...")

# close all SPARQL async clients
if not settings.sparql_repo_type:
await app.state.http_async_client.aclose()
if not _settings.sparql_repo_type:
await _app.state.http_async_client.aclose()


def assemble_app(
root_path: str = "",
title: Optional[str] = None,
description: Optional[str] = None,
version: Optional[str] = None,
local_settings: Optional[Settings] = None,
**kwargs
):

_settings = local_settings if local_settings is not None else settings

if title is None:
title = _settings.prez_title
if version is None:
version = _settings.prez_version
if description is None:
description = _settings.prez_desc
contact = _settings.prez_contact

app = FastAPI(
root_path=root_path,
title=title,
version=version,
description=description,
contact=contact,
exception_handlers={
400: catch_400,
404: catch_404,
500: catch_500,
ClassNotFoundException: catch_class_not_found_exception,
URINotFoundException: catch_uri_not_found_exception,
NoProfilesException: catch_no_profiles_exception,
},
**kwargs
)

app.include_router(cql_router)
app.include_router(management_router)
app.include_router(object_router)
app.include_router(sparql_router)
app.include_router(search_router)
app.include_router(profiles_router)
if "CatPrez" in _settings.prez_flavours:
app.include_router(catprez_router)
if "VocPrez" in _settings.prez_flavours:
app.include_router(vocprez_router)
if "SpacePrez" in _settings.prez_flavours:
app.include_router(spaceprez_router)
app.include_router(identifier_router)
app.openapi = partial(
prez_open_api_metadata,
title=title,
description=description,
version=version,
contact=contact,
_app=app,
)

app.middleware("http")(add_cors_headers)

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["*"],
)
app.on_event("startup")(partial(app_startup, _settings=_settings, _app=app))
app.on_event("shutdown")(partial(app_shutdown, _settings=_settings, _app=app))
return app


def _get_sparql_service_description(request, format):
Expand Down Expand Up @@ -200,4 +243,4 @@ def _get_sparql_service_description(request, format):


if __name__ == "__main__":
uvicorn.run("app:app", port=settings.port, host=settings.host)
uvicorn.run(assemble_app, factory=True, port=settings.port, host=settings.host)
3 changes: 2 additions & 1 deletion prez/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from os import environ
from pathlib import Path
from typing import Optional
from typing import Optional, Union, Any, Dict

import toml
from pydantic import BaseSettings, root_validator
Expand Down Expand Up @@ -58,6 +58,7 @@ class Settings(BaseSettings):
"Knowledge Graph data which can be subset according to information profiles."
)
prez_version: Optional[str]
prez_contact: Optional[Dict[str, Union[str, Any]]]
disable_prefix_generation: bool = False
local_rdf_dir: str = "rdf"

Expand Down
4 changes: 3 additions & 1 deletion tests/test_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi.testclient import TestClient
from pyoxigraph.pyoxigraph import Store

from prez.app import app
from prez.app import assemble_app
from prez.dependencies import get_repo
from prez.sparql.methods import Repo, PyoxigraphRepo

Expand Down Expand Up @@ -32,6 +32,8 @@ def test_client(test_repo: Repo) -> TestClient:
def override_get_repo():
return test_repo

app = assemble_app()

app.dependency_overrides[get_repo] = override_get_repo

with TestClient(app) as c:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_curie_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import pytest
from fastapi.testclient import TestClient

from prez.app import app
from prez.app import assemble_app


@pytest.fixture
def client() -> TestClient:
app = assemble_app()
testclient = TestClient(app)

# Make a request for the following IRI to ensure
Expand Down
4 changes: 3 additions & 1 deletion tests/test_dd_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi.testclient import TestClient
from pyoxigraph.pyoxigraph import Store

from prez.app import app
from prez.app import assemble_app
from prez.dependencies import get_repo
from prez.sparql.methods import Repo, PyoxigraphRepo

Expand Down Expand Up @@ -34,6 +34,8 @@ def test_client(test_repo: Repo) -> TestClient:
def override_get_repo():
return test_repo

app = assemble_app()

app.dependency_overrides[get_repo] = override_get_repo

with TestClient(app) as c:
Expand Down
4 changes: 3 additions & 1 deletion tests/test_endpoints_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pyoxigraph.pyoxigraph import Store
from rdflib import Graph

from prez.app import app
from prez.app import assemble_app
from prez.dependencies import get_repo
from prez.sparql.methods import Repo, PyoxigraphRepo

Expand Down Expand Up @@ -33,6 +33,8 @@ def test_client(test_repo: Repo) -> TestClient:
def override_get_repo():
return test_repo

app = assemble_app()

app.dependency_overrides[get_repo] = override_get_repo

with TestClient(app) as c:
Expand Down
4 changes: 3 additions & 1 deletion tests/test_endpoints_catprez.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from rdflib import Graph, URIRef
from rdflib.namespace import RDF, DCAT

from prez.app import app
from prez.app import assemble_app
from prez.dependencies import get_repo
from prez.sparql.methods import Repo, PyoxigraphRepo

Expand Down Expand Up @@ -34,6 +34,8 @@ def client(test_repo: Repo) -> TestClient:
def override_get_repo():
return test_repo

app = assemble_app()

app.dependency_overrides[get_repo] = override_get_repo

with TestClient(app) as c:
Expand Down
4 changes: 3 additions & 1 deletion tests/test_endpoints_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pyoxigraph.pyoxigraph import Store
from rdflib import Graph

from prez.app import app
from prez.app import assemble_app
from prez.dependencies import get_repo
from prez.reference_data.prez_ns import PREZ
from prez.sparql.methods import Repo, PyoxigraphRepo
Expand Down Expand Up @@ -34,6 +34,8 @@ def client(test_repo: Repo) -> TestClient:
def override_get_repo():
return test_repo

app = assemble_app()

app.dependency_overrides[get_repo] = override_get_repo

with TestClient(app) as c:
Expand Down
Loading

0 comments on commit 45e8426

Please sign in to comment.