Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for extensions to the HTTP protocol #3461

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions strawberry/http/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class GraphQLRequestData:
query: Optional[str]
variables: Optional[Dict[str, Any]]
operation_name: Optional[str]
extensions: Optional[Dict[str, Any]]


def parse_query_params(params: Dict[str, str]) -> Dict[str, Any]:
Expand All @@ -47,4 +48,5 @@ def parse_request_data(data: Mapping[str, Any]) -> GraphQLRequestData:
query=data.get("query"),
variables=data.get("variables"),
operation_name=data.get("operationName"),
extensions=data.get("extensions"),
)
8 changes: 7 additions & 1 deletion strawberry/http/async_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from strawberry.schema.base import BaseSchema
from strawberry.schema.exceptions import InvalidOperationTypeError
from strawberry.types import ExecutionResult
from strawberry.types.context_wrapper import ContextWrapper
from strawberry.types.graphql import OperationType

from .base import BaseView
Expand Down Expand Up @@ -102,11 +103,15 @@ async def execute_operation(

assert self.schema

context_wrapper = ContextWrapper(
context=context, extensions=request_data.extensions
)

return await self.schema.execute(
request_data.query,
root_value=root_value,
variable_values=request_data.variables,
context_value=context,
context_value=context_wrapper,
operation_name=request_data.operation_name,
allowed_operation_types=allowed_operation_types,
)
Expand Down Expand Up @@ -207,6 +212,7 @@ async def parse_http_body(
query=data.get("query"),
variables=data.get("variables"),
operation_name=data.get("operationName"),
extensions=data.get("extensions"),
)

async def process_result(
Expand Down
6 changes: 6 additions & 0 deletions strawberry/http/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def parse_query_params(self, params: QueryParams) -> Dict[str, Any]:
if variables:
params["variables"] = self.parse_json(variables)

if "extensions" in params:
extensions = params["extensions"]

if extensions:
params["extensions"] = self.parse_json(extensions)

return params

@property
Expand Down
8 changes: 7 additions & 1 deletion strawberry/http/sync_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from strawberry.schema import BaseSchema
from strawberry.schema.exceptions import InvalidOperationTypeError
from strawberry.types import ExecutionResult
from strawberry.types.context_wrapper import ContextWrapper
from strawberry.types.graphql import OperationType

from .base import BaseView
Expand Down Expand Up @@ -112,11 +113,15 @@ def execute_operation(

assert self.schema

context_wrapper = ContextWrapper(
context=context, extensions=request_data.extensions
)
Comment on lines +116 to +118
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this context wrapper only to pass request data?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes exactly. Since the context data comes from get_context and that can be any type of object the user decides to return, this felt like a safe way to attach more data

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll try to review and merge this this week 😊 thanks for the patience!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you so much! 🙏


return self.schema.execute_sync(
request_data.query,
root_value=root_value,
variable_values=request_data.variables,
context_value=context,
context_value=context_wrapper,
operation_name=request_data.operation_name,
allowed_operation_types=allowed_operation_types,
)
Expand Down Expand Up @@ -146,6 +151,7 @@ def parse_http_body(self, request: SyncHTTPRequestAdapter) -> GraphQLRequestData
query=data.get("query"),
variables=data.get("variables"),
operation_name=data.get("operationName"),
extensions=data.get("extensions"),
)

def _handle_errors(
Expand Down
8 changes: 8 additions & 0 deletions strawberry/types/context_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class ContextWrapper:
context: Optional[Any]
extensions: Optional[Dict[str, Any]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same suggestion as above

Suggested change
extensions: Optional[Dict[str, Any]]
extensions: Optional[Dict[str, Any]] = None

9 changes: 9 additions & 0 deletions strawberry/types/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from typing_extensions import TypeVar

from .context_wrapper import ContextWrapper
from .nodes import convert_selections

if TYPE_CHECKING:
Expand Down Expand Up @@ -79,8 +80,16 @@ def selected_fields(self) -> List[Selection]:

@property
def context(self) -> ContextType:
if isinstance(self._raw_info.context, ContextWrapper):
return self._raw_info.context.context
return self._raw_info.context

@property
def input_extensions(self) -> Dict[str, Any]:
if isinstance(self._raw_info.context, ContextWrapper):
return self._raw_info.context.extensions
return {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how come this is called input_extensions? 😊

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

honestly, I just didn't have a better name 😅 I just felt extensions is used in so many contexts in strawberry and people might be confused as to which one this is. But I'm happy to change it to whatever you think makes most sense


@property
def root_value(self) -> RootValueType:
return self._raw_info.root_value
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,16 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
async with TestClient(TestServer(self.app)) as client:
body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

if body and files:
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,15 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

if method == "get":
Expand Down
16 changes: 15 additions & 1 deletion tests/http/clients/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response: ...

Expand Down Expand Up @@ -89,9 +90,15 @@ async def query(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
) -> Response:
return await self._graphql_request(
method, query=query, headers=headers, variables=variables, files=files
method,
query=query,
headers=headers,
variables=variables,
files=files,
extensions=extensions,
)

def _get_headers(
Expand All @@ -117,6 +124,7 @@ def _build_body(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
method: Literal["get", "post"] = "post",
extensions: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, object]]:
if query is None:
assert files is None
Expand All @@ -129,6 +137,9 @@ def _build_body(
if variables:
body["variables"] = variables

if extensions:
body["extensions"] = extensions

if files:
assert variables is not None

Expand All @@ -142,6 +153,9 @@ def _build_body(
if method == "get" and variables:
body["variables"] = json.dumps(variables)

if method == "get" and extensions:
body["extensions"] = json.dumps(extensions)

return body

@staticmethod
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/chalice.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,15 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

data: Union[Dict[str, object], str, None] = None
Expand Down
16 changes: 13 additions & 3 deletions tests/http/clients/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,16 @@


def generate_get_path(
path, query: str, variables: Optional[Dict[str, Any]] = None
path,
query: str,
variables: Optional[Dict[str, Any]] = None,
extensions: Optional[Dict[str, Any]] = None,
) -> str:
body: Dict[str, Any] = {"query": query}
if variables is not None:
body["variables"] = json_module.dumps(variables)
if extensions is not None:
body["extensions"] = json_module.dumps(extensions)

parts = [f"{k}={v}" for k, v in body.items()]
return f"{path}?{'&'.join(parts)}"
Expand Down Expand Up @@ -167,10 +172,15 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

headers = self._get_headers(method=method, headers=headers, files=files)
Expand All @@ -185,7 +195,7 @@ async def _graphql_request(
endpoint_url = "/graphql"
else:
body = b""
endpoint_url = generate_get_path("/graphql", query, variables)
endpoint_url = generate_get_path("/graphql", query, variables, extensions)

return await self.request(
url=endpoint_url, method=method, body=body, headers=headers
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,18 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
headers = self._get_headers(method=method, headers=headers, files=files)
additional_arguments = {**kwargs, **headers}

body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

data: Union[Dict[str, object], str, None] = None
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,15 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

if body:
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,15 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

data: Union[Dict[str, object], str, None] = None
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/litestar.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,15 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
if body := self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
):
if method == "get":
kwargs["params"] = body
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/quart.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,15 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

url = "/graphql"
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/sanic.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,15 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

if body:
Expand Down
7 changes: 6 additions & 1 deletion tests/http/clients/starlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,15 @@ async def _graphql_request(
variables: Optional[Dict[str, object]] = None,
files: Optional[Dict[str, BytesIO]] = None,
headers: Optional[Dict[str, str]] = None,
extensions: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Response:
body = self._build_body(
query=query, variables=variables, files=files, method=method
query=query,
variables=variables,
files=files,
method=method,
extensions=extensions,
)

if body:
Expand Down
Loading
Loading