Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attempt graphql backend #71

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ dependencies = [
"pvi~=0.10.0",
"pytango",
"softioc",
"fastapi[standard]",
"strawberry-graphql[fastapi]",
]
dynamic = ["version"]
license.file = "LICENSE"
Expand All @@ -43,6 +45,7 @@ dev = [
"types-mock",
"aioca",
"p4p",
"httpx",
]

[project.scripts]
Expand Down
Empty file.
14 changes: 14 additions & 0 deletions src/fastcs/backends/graphQL/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from fastcs.backend import Backend
from fastcs.controller import Controller

from .graphQL import GraphQLServer


class GraphQLBackend(Backend):
def __init__(self, controller: Controller):
super().__init__(controller)

self._server = GraphQLServer(self._mapping)

def _run(self):
self._server.run()
226 changes: 226 additions & 0 deletions src/fastcs/backends/graphQL/graphQL.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
from collections.abc import Awaitable, Callable, Coroutine
from dataclasses import dataclass
from typing import Any

import strawberry
import uvicorn
from fastapi import FastAPI
from strawberry.asgi import GraphQL
from strawberry.tools import create_type
from strawberry.types.field import StrawberryField

from fastcs.attributes import AttrR, AttrRW, AttrW, T
from fastcs.controller import BaseController
from fastcs.mapping import Mapping


@dataclass
class GraphQLServerOptions:
host: str = "localhost"
port: int = 8080
log_level: str = "info"


class GraphQLServer:
def __init__(self, mapping: Mapping):
self._mapping = mapping
self._fields_tree: FieldsTree = FieldsTree("")
self._app = self._create_app()

def _create_app(self) -> FastAPI:
_add_dev_attributes(self._fields_tree, self._mapping)
_add_dev_commands(self._fields_tree, self._mapping)

schema_kwargs = {}
for key in ["query", "mutation"]:
if s_type := self._fields_tree.create_type(key):
schema_kwargs[key] = s_type
schema = strawberry.Schema(**schema_kwargs) # type: ignore
graphql_app: GraphQL = GraphQL(schema)

app = FastAPI()
app.add_route("/graphql", graphql_app) # type: ignore
app.add_websocket_route("/graphql", graphql_app) # type: ignore

return app

def run(self, options: GraphQLServerOptions | None = None) -> None:
if options is None:
options = GraphQLServerOptions()

uvicorn.run(
self._app,
host=options.host,
port=options.port,
log_level=options.log_level,
)


def _wrap_attr_set(
d_attr_name: str,
attribute: AttrW[T],
) -> Callable[[T], Coroutine[Any, Any, None]]:
async def _dynamic_f(value):
await attribute.process(value)
return value

# Add type annotations for validation, schema, conversions
_dynamic_f.__name__ = d_attr_name
_dynamic_f.__annotations__["value"] = attribute.datatype.dtype
_dynamic_f.__annotations__["return"] = attribute.datatype.dtype

return _dynamic_f


def _wrap_attr_get(
d_attr_name: str,
attribute: AttrR[T],
) -> Callable[[], Coroutine[Any, Any, Any]]:
async def _dynamic_f() -> Any:
return attribute.get()

_dynamic_f.__name__ = d_attr_name
_dynamic_f.__annotations__["return"] = attribute.datatype.dtype

return _dynamic_f


def _wrap_as_field(
field_name: str,
strawberry_type: type,
) -> StrawberryField:
def _dynamic_field():
return strawberry_type()

_dynamic_field.__name__ = field_name
_dynamic_field.__annotations__["return"] = strawberry_type

return strawberry.field(_dynamic_field)


class NodeNotFoundError(Exception):
pass


class FieldsTree:
def __init__(self, name: str):
self.name = name
self.children: list[FieldsTree] = []
self.fields_dict: dict[str, list[StrawberryField]] = {
"query": [],
"mutation": [],
}

def insert(self, path: list[str]) -> "FieldsTree":
# Create child if not exist
name = path.pop(0)
if self.is_child(name):
child = self.get_child(name)
else:
child = FieldsTree(name)
self.children.append(child)

# Recurse if needed
if path:
return child.insert(path) # type: ignore
else:
return child

def is_child(self, name: str) -> bool:
for child in self.children:
if child.name == name:
return True
return False

def get_child(self, name: str) -> "FieldsTree":
for child in self.children:
if child.name == name:
return child
raise NodeNotFoundError

def create_type(self, strawberry_type: str) -> type | None:
for child in self.children:
if new_type := child.create_type(strawberry_type):
child_field = _wrap_as_field(
child.name,
new_type,
)
self.fields_dict[strawberry_type].append(child_field)

if self.fields_dict[strawberry_type]:
return create_type(
f"{self.name}{strawberry_type}", self.fields_dict[strawberry_type]
)
else:
return None


def _add_dev_attributes(
fields_tree: FieldsTree,
mapping: Mapping,
) -> None:
for single_mapping in mapping.get_controller_mappings():
path = single_mapping.controller.path
if path:
node = fields_tree.insert(path)
else:
node = fields_tree

if node is not None:
for attr_name, attribute in single_mapping.attributes.items():
attr_name = attr_name.title().replace("_", "")

match attribute:
# mutation for server changes https://graphql.org/learn/queries/
case AttrRW():
node.fields_dict["query"].append(
strawberry.field(_wrap_attr_get(attr_name, attribute))
)
node.fields_dict["mutation"].append(
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
)
case AttrR():
node.fields_dict["query"].append(
strawberry.field(_wrap_attr_get(attr_name, attribute))
)
case AttrW():
node.fields_dict["mutation"].append(
strawberry.mutation(_wrap_attr_set(attr_name, attribute))
)


def _wrap_command(
method_name: str, method: Callable, controller: BaseController
) -> Callable[..., Awaitable[bool]]:
async def _dynamic_f() -> bool:
await getattr(controller, method.__name__)()
return True

_dynamic_f.__name__ = method_name

return _dynamic_f


def _add_dev_commands(
fields_tree: FieldsTree,
mapping: Mapping,
) -> None:
for single_mapping in mapping.get_controller_mappings():
path = single_mapping.controller.path
if path:
node = fields_tree.insert(path)
else:
node = fields_tree

if node is not None:
for name, method in single_mapping.command_methods.items():
cmd_name = name.title().replace("_", "")
node.fields_dict["mutation"].append(
strawberry.mutation(
_wrap_command(
cmd_name,
method.fn,
single_mapping.controller,
)
)
)
126 changes: 126 additions & 0 deletions tests/backends/graphQL/test_graphQL.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import copy
import json
import re
from typing import Any

import pytest
from fastapi.testclient import TestClient

from fastcs.attributes import AttrR
from fastcs.backends.graphQL.backend import GraphQLBackend
from fastcs.datatypes import Bool, Float, Int


def pascal_2_snake(input: list[str]) -> list[str]:
snake_list = copy.deepcopy(input)
snake_list[-1] = re.sub(r"(?<!^)(?=[A-Z])", "_", snake_list[-1]).lower()
return snake_list


def nest_query(path: list[str]) -> str:
queue = copy.deepcopy(path)
field = queue.pop(0)

if queue:
nesting = nest_query(queue)
return f"{field} {{ {nesting} }} "
else:
return field


def nest_mutation(path: list[str], value: Any) -> str:
queue = copy.deepcopy(path)
field = queue.pop(0)

if queue:
nesting = nest_query(queue)
return f"{field} {{ {nesting} }} "
else:
return f"{field}(value: {json.dumps(value)})"


def nest_responce(path: list[str], value: Any) -> dict:
queue = copy.deepcopy(path)
field = queue.pop(0)

if queue:
nesting = nest_responce(queue, value)
return {field: nesting}
else:
return {field: value}


class TestGraphQLServer:
@pytest.fixture(scope="class", autouse=True)
def setup_class(self, assertable_controller):
self.controller = assertable_controller

@pytest.fixture(scope="class")
def client(self):
app = GraphQLBackend(self.controller)._server._app
return TestClient(app)

@pytest.fixture(scope="class")
def client_read(self, client):
def _client_read(path: list[str], expected: Any):
query = f"query {{ {nest_query(path)} }}"
with self.controller.assertPerformed(pascal_2_snake(path), "READ"):
response = client.post("/graphql", json={"query": query})
assert response.status_code == 200
assert response.json()["data"] == nest_responce(path, expected)

return _client_read

@pytest.fixture(scope="class")
def client_write(self, client):
def _client_write(path: list[str], value: Any):
mutation = f"mutation {{ {nest_mutation(path, value)} }}"
with self.controller.assertPerformed(pascal_2_snake(path), "WRITE"):
response = client.post("/graphql", json={"query": mutation})
assert response.status_code == 200
assert response.json()["data"] == nest_responce(path, value)

return _client_write

@pytest.fixture(scope="class")
def client_exec(self, client):
def _client_exec(path: list[str]):
mutation = f"mutation {{ {nest_query(path)} }}"
with self.controller.assertPerformed(pascal_2_snake(path), "EXECUTE"):
response = client.post("/graphql", json={"query": mutation})
assert response.status_code == 200
assert response.json()["data"] == {path[-1]: True}

return _client_exec

def test_read_int(self, client_read):
client_read(["ReadInt"], AttrR(Int())._value)

def test_read_write_int(self, client_read, client_write):
client_read(["ReadWriteInt"], AttrR(Int())._value)
client_write(["ReadWriteInt"], AttrR(Int())._value)

def test_read_write_float(self, client_read, client_write):
client_read(["ReadWriteFloat"], AttrR(Float())._value)
client_write(["ReadWriteFloat"], AttrR(Float())._value)

def test_read_bool(self, client_read):
client_read(["ReadBool"], AttrR(Bool())._value)

def test_write_bool(self, client_write):
client_write(["WriteBool"], AttrR(Bool())._value)

# # We need to discuss enums
# def test_string_enum(self, client_read, client_write):

def test_big_enum(self, client_read):
client_read(["BigEnum"], AttrR(Int(), allowed_values=list(range(1, 18)))._value)

def test_go(self, client_exec):
client_exec(["Go"])

def test_read_child1(self, client_read):
client_read(["SubController01", "ReadInt"], AttrR(Int())._value)

def test_read_child2(self, client_read):
client_read(["SubController02", "ReadInt"], AttrR(Int())._value)
Loading
Loading