Skip to content

Commit

Permalink
Merge pull request #5164 from opsmill/pog-generate-graphql-schema
Browse files Browse the repository at this point in the history
Generate GraphQL schema during startup
  • Loading branch information
ogenstad authored Dec 13, 2024
2 parents f35cc55 + 0d43df1 commit 0f28748
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 7 deletions.
11 changes: 11 additions & 0 deletions backend/infrahub/core/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from infrahub.core.schema.manager import SchemaManager
from infrahub.database import InfrahubDatabase
from infrahub.exceptions import DatabaseError
from infrahub.graphql.manager import GraphQLSchemaManager
from infrahub.log import get_logger
from infrahub.menu.menu import default_menu
from infrahub.menu.utils import create_menu_children
Expand Down Expand Up @@ -178,6 +179,16 @@ async def initialization(db: InfrahubDatabase) -> None:
)
await branch.save(db=db)

default_branch = registry.get_branch_from_registry(branch=registry.default_branch)
schema_branch = registry.schema.get_schema_branch(name=default_branch.name)
gqlm = GraphQLSchemaManager.get_manager_for_branch(branch=default_branch, schema_branch=schema_branch)
gqlm.get_graphql_schema(
include_query=True,
include_mutation=True,
include_subscription=True,
include_types=True,
)

# ---------------------------------------------------
# Load Default Namespace
# ---------------------------------------------------
Expand Down
17 changes: 17 additions & 0 deletions backend/infrahub/tasks/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ async def refresh_branches(db: InfrahubDatabase) -> None:
If a branch is already present with a different value for the hash
We pull the new schema from the database and we update the registry.
"""
from infrahub.graphql.manager import GraphQLSchemaManager # pylint: disable=import-outside-toplevel,cyclic-import

async with lock.registry.local_schema_lock():
branches = await registry.branch_object.get_list(db=db)
Expand All @@ -38,11 +39,27 @@ async def refresh_branches(db: InfrahubDatabase) -> None:
)
await registry.schema.load_schema(db=db, branch=new_branch)
registry.branch[new_branch.name] = new_branch
schema_branch = registry.schema.get_schema_branch(name=new_branch.name)
gqlm = GraphQLSchemaManager.get_manager_for_branch(branch=new_branch, schema_branch=schema_branch)
gqlm.get_graphql_schema(
include_query=True,
include_mutation=True,
include_subscription=True,
include_types=True,
)

else:
log.info("New branch detected, pulling schema", branch=new_branch.name, worker=WORKER_IDENTITY)
await registry.schema.load_schema(db=db, branch=new_branch)
registry.branch[new_branch.name] = new_branch
schema_branch = registry.schema.get_schema_branch(name=new_branch.name)
gqlm = GraphQLSchemaManager.get_manager_for_branch(branch=new_branch, schema_branch=schema_branch)
gqlm.get_graphql_schema(
include_query=True,
include_mutation=True,
include_subscription=True,
include_types=True,
)

for branch_name in list(registry.branch.keys()):
if branch_name not in active_branches:
Expand Down
2 changes: 2 additions & 0 deletions backend/tests/helpers/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from infrahub.core import registry
from infrahub.core.schema import SchemaRoot
from infrahub.graphql.manager import GraphQLSchemaManager

from .car import CAR
from .manufacturer import MANUFACTURER
Expand All @@ -28,6 +29,7 @@ async def load_schema(db: InfrahubDatabase, schema: SchemaRoot, branch_name: str
await registry.schema.update_schema_branch(
schema=tmp_schema, db=db, branch=branch_name or default_branch_name, update_db=True
)
GraphQLSchemaManager.clear_cache()


__all__ = ["CAR", "CAR_SCHEMA", "MANUFACTURER", "PERSON", "TICKET", "WIDGET"]
4 changes: 3 additions & 1 deletion backend/tests/unit/api/test_50_config_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from infrahub.database import InfrahubDatabase


async def test_config_endpoint(db: InfrahubDatabase, client, client_headers, default_branch):
async def test_config_endpoint(
db: InfrahubDatabase, client, client_headers, default_branch, register_core_models_schema: None
):
with client:
response = client.get(
"/api/config",
Expand Down
10 changes: 8 additions & 2 deletions backend/tests/unit/api/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import jwt
from fastapi.testclient import TestClient

from infrahub import config
from infrahub.core.branch import Branch
from infrahub.database import InfrahubDatabase

EXPIRED_ACCESS_TOKEN = (
Expand Down Expand Up @@ -163,7 +165,9 @@ async def test_password_based_login_invalid_password(db: InfrahubDatabase, defau
}


async def test_use_expired_token(db: InfrahubDatabase, default_branch, client):
async def test_use_expired_token(
db: InfrahubDatabase, default_branch: Branch, client: TestClient, register_core_models_schema: None
) -> None:
with client:
response = client.get(
"/api/transform/jinja2/testing", headers={"Authorization": f"Bearer {EXPIRED_ACCESS_TOKEN}"}
Expand All @@ -173,7 +177,9 @@ async def test_use_expired_token(db: InfrahubDatabase, default_branch, client):
assert response.json() == {"data": None, "errors": [{"message": "Expired Signature", "extensions": {"code": 401}}]}


async def test_refresh_access_token_with_expired_refresh_token(db: InfrahubDatabase, default_branch, client):
async def test_refresh_access_token_with_expired_refresh_token(
db: InfrahubDatabase, default_branch: Branch, client: TestClient, register_core_models_schema: None
) -> None:
"""Validate that the correct error is returned for an expired refresh token"""
with client:
response = client.post("/api/auth/refresh", headers={"Authorization": f"Bearer {EXPIRED_REFRESH_TOKEN}"})
Expand Down
7 changes: 3 additions & 4 deletions backend/tests/unit/api/test_openapi.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from fastapi.testclient import TestClient

from infrahub.core.branch import Branch


async def test_openapi(
client,
default_branch: Branch,
):
async def test_openapi(client: TestClient, default_branch: Branch, register_core_models_schema: None) -> None:
"""Validate that the OpenAPI specs can be generated."""
with client:
response = client.get(
Expand Down

0 comments on commit 0f28748

Please sign in to comment.