From af9e3966a1755a85e031e1c41afb32e85c0581d1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Dec 2024 01:56:40 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- integration_tests/base_routes.py | 84 +++++-------- integration_tests/conftest.py | 5 +- integration_tests/test_streaming_responses.py | 19 +-- robyn/__init__.py | 119 +++++++----------- robyn/responses.py | 2 +- robyn/robyn.pyi | 2 +- robyn/router.py | 3 +- 7 files changed, 91 insertions(+), 143 deletions(-) diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py index 46d52ff62..5ccaa2fef 100644 --- a/integration_tests/base_routes.py +++ b/integration_tests/base_routes.py @@ -750,14 +750,9 @@ def handle_exception(error): # Create a response with proper error handling response = Response( status_code=500, - headers=Headers({ - "Content-Type": "text/plain", - "X-Error-Response": "true", - "global_after": "global_after_request", - "server": "robyn" - }), + headers=Headers({"Content-Type": "text/plain", "X-Error-Response": "true", "global_after": "global_after_request", "server": "robyn"}), description=f"error msg: {error}".encode(), - streaming=False + streaming=False, ) return response @@ -1103,30 +1098,25 @@ def create_item(request, body: CreateItemBody, query: CreateItemQueryParamsParam # --- Streaming responses --- + @app.get("/stream/sync", streaming=True) async def sync_stream(): def generator(): for i in range(5): yield f"Chunk {i}\n".encode() - + headers = Headers({"Content-Type": "text/plain"}) - return Response( - status_code=200, - description=generator(), - headers=headers - ) + return Response(status_code=200, description=generator(), headers=headers) + @app.get("/stream/async", streaming=True) async def async_stream(): async def generator(): for i in range(5): yield f"Async Chunk {i}\n".encode() - - return Response( - status_code=200, - headers={"Content-Type": "text/plain"}, - description=generator() - ) + + return Response(status_code=200, headers={"Content-Type": "text/plain"}, description=generator()) + @app.get("/stream/mixed", streaming=True) async def mixed_stream(): @@ -1135,12 +1125,9 @@ async def generator(): yield "String chunk\n".encode() yield str(42).encode() + b"\n" yield json.dumps({"message": "JSON chunk", "number": 123}).encode() + b"\n" - - return Response( - status_code=200, - headers={"Content-Type": "text/plain"}, - description=generator() - ) + + return Response(status_code=200, headers={"Content-Type": "text/plain"}, description=generator()) + @app.get("/stream/events", streaming=True) async def server_sent_events(): @@ -1148,73 +1135,60 @@ async def event_generator(): import asyncio import json import time - + # Regular event yield f"event: message\ndata: {json.dumps({'time': time.time(), 'type': 'start'})}\n\n".encode() await asyncio.sleep(1) - + # Event with ID yield f"id: 1\nevent: update\ndata: {json.dumps({'progress': 50})}\n\n".encode() await asyncio.sleep(1) - + # Multiple data lines - data = json.dumps({'status': 'complete', 'results': [1, 2, 3]}, indent=2) + data = json.dumps({"status": "complete", "results": [1, 2, 3]}, indent=2) yield f"event: complete\ndata: {data}\n\n".encode() - + return Response( - status_code=200, - headers={ - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - "Connection": "keep-alive" - }, - description=event_generator() + status_code=200, headers={"Content-Type": "text/event-stream", "Cache-Control": "no-cache", "Connection": "keep-alive"}, description=event_generator() ) + @app.get("/stream/large-file", streaming=True) async def stream_large_file(): async def file_generator(): # Simulate streaming a large file in chunks chunk_size = 1024 # 1KB chunks total_size = 10 * chunk_size # 10KB total - + for offset in range(0, total_size, chunk_size): # Simulate reading file chunk chunk = b"X" * min(chunk_size, total_size - offset) yield chunk - + return Response( status_code=200, - headers={ - "Content-Type": "application/octet-stream", - "Content-Disposition": "attachment; filename=large-file.bin" - }, - description=file_generator() + headers={"Content-Type": "application/octet-stream", "Content-Disposition": "attachment; filename=large-file.bin"}, + description=file_generator(), ) + @app.get("/stream/csv", streaming=True) async def stream_csv(): async def csv_generator(): # CSV header yield "id,name,value\n".encode() - + import asyncio import random - + # Generate rows for i in range(5): await asyncio.sleep(0.5) # Simulate data processing row = f"{i},item-{i},{random.randint(1, 100)}\n" yield row.encode() - - return Response( - status_code=200, - headers={ - "Content-Type": "text/csv", - "Content-Disposition": "attachment; filename=data.csv" - }, - description=csv_generator() - ) + + return Response(status_code=200, headers={"Content-Type": "text/csv", "Content-Disposition": "attachment; filename=data.csv"}, description=csv_generator()) + def main(): app.set_response_header("server", "robyn") diff --git a/integration_tests/conftest.py b/integration_tests/conftest.py index c450919df..77c661e0b 100644 --- a/integration_tests/conftest.py +++ b/integration_tests/conftest.py @@ -8,12 +8,10 @@ from typing import List import pytest -import pytest_asyncio -from robyn import Robyn -from integration_tests.base_routes import app from integration_tests.helpers.network_helpers import get_network_host + def spawn_process(command: List[str]) -> subprocess.Popen: if platform.system() == "Windows": command[0] = "python" @@ -129,4 +127,3 @@ def env_file(): env_path.unlink() del os.environ["ROBYN_PORT"] del os.environ["ROBYN_HOST"] - diff --git a/integration_tests/test_streaming_responses.py b/integration_tests/test_streaming_responses.py index 4e16f9504..be5d8d0a3 100644 --- a/integration_tests/test_streaming_responses.py +++ b/integration_tests/test_streaming_responses.py @@ -19,6 +19,7 @@ # Mark all tests in this module as async pytestmark = pytest.mark.asyncio + async def test_sync_stream(): """Test basic synchronous streaming response.""" async with aiohttp.ClientSession() as client: @@ -34,6 +35,7 @@ async def test_sync_stream(): for i, chunk in enumerate(chunks): assert chunk == f"Chunk {i}\n" + async def test_async_stream(): """Test asynchronous streaming response.""" async with aiohttp.ClientSession() as client: @@ -49,6 +51,7 @@ async def test_async_stream(): for i, chunk in enumerate(chunks): assert chunk == f"Async Chunk {i}\n" + async def test_mixed_stream(): """Test streaming of mixed content types.""" async with aiohttp.ClientSession() as client: @@ -56,12 +59,7 @@ async def test_mixed_stream(): assert response.status == 200 assert response.headers["Content-Type"] == "text/plain" - expected = [ - b"Binary chunk\n", - b"String chunk\n", - b"42\n", - json.dumps({"message": "JSON chunk", "number": 123}).encode() + b"\n" - ] + expected = [b"Binary chunk\n", b"String chunk\n", b"42\n", json.dumps({"message": "JSON chunk", "number": 123}).encode() + b"\n"] chunks = [] async for chunk in response.content: @@ -71,6 +69,7 @@ async def test_mixed_stream(): for chunk, expected_chunk in zip(chunks, expected): assert chunk == expected_chunk + async def test_server_sent_events(): """Test Server-Sent Events (SSE) streaming.""" async with aiohttp.ClientSession() as client: @@ -103,6 +102,7 @@ async def test_server_sent_events(): assert event_data["status"] == "complete" assert event_data["results"] == [1, 2, 3] + async def test_large_file_stream(): """Test streaming of large files in chunks.""" async with aiohttp.ClientSession() as client: @@ -118,6 +118,7 @@ async def test_large_file_stream(): assert total_size == 10 * 1024 # 10KB total + async def test_csv_stream(): """Test streaming of CSV data.""" async with aiohttp.ClientSession() as client: @@ -132,11 +133,11 @@ async def test_csv_stream(): # Verify header assert lines[0] == "id,name,value" - + # Verify data rows assert len(lines) == 6 # Header + 5 data rows for i, line in enumerate(lines[1:], 0): - id_, name, value = line.split(',') + id_, name, value = line.split(",") assert int(id_) == i assert name == f"item-{i}" - assert 1 <= int(value) <= 100 \ No newline at end of file + assert 1 <= int(value) <= 100 diff --git a/robyn/__init__.py b/robyn/__init__.py index f2c5369b2..b694a44f7 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -338,14 +338,7 @@ def get_functions(view) -> List[Tuple[HttpMethod, Callable]]: handlers = get_functions(view) for route_type, handler in handlers: - self.add_route( - route_type=route_type, - endpoint=endpoint, - handler=handler, - is_const=const, - streaming=False, - auth_required=False - ) + self.add_route(route_type=route_type, endpoint=endpoint, handler=handler, is_const=const, streaming=False, auth_required=False) def view(self, endpoint: str, const: bool = False): """ @@ -383,12 +376,7 @@ def inner(handler): self.openapi.add_openapi_path_obj("get", endpoint, openapi_name, openapi_tags, handler) return self.add_route( - route_type=HttpMethod.GET, - endpoint=endpoint, - handler=handler, - is_const=const, - streaming=streaming, - auth_required=auth_required + route_type=HttpMethod.GET, endpoint=endpoint, handler=handler, is_const=const, streaming=streaming, auth_required=auth_required ) return inner @@ -415,12 +403,7 @@ def inner(handler): self.openapi.add_openapi_path_obj("post", endpoint, openapi_name, openapi_tags, handler) return self.add_route( - route_type=HttpMethod.POST, - endpoint=endpoint, - handler=handler, - is_const=False, - streaming=streaming, - auth_required=auth_required + route_type=HttpMethod.POST, endpoint=endpoint, handler=handler, is_const=False, streaming=streaming, auth_required=auth_required ) return inner @@ -447,12 +430,7 @@ def inner(handler): self.openapi.add_openapi_path_obj("put", endpoint, openapi_name, openapi_tags, handler) return self.add_route( - route_type=HttpMethod.PUT, - endpoint=endpoint, - handler=handler, - is_const=False, - streaming=streaming, - auth_required=auth_required + route_type=HttpMethod.PUT, endpoint=endpoint, handler=handler, is_const=False, streaming=streaming, auth_required=auth_required ) return inner @@ -479,12 +457,7 @@ def inner(handler): self.openapi.add_openapi_path_obj("delete", endpoint, openapi_name, openapi_tags, handler) return self.add_route( - route_type=HttpMethod.DELETE, - endpoint=endpoint, - handler=handler, - is_const=False, - streaming=streaming, - auth_required=auth_required + route_type=HttpMethod.DELETE, endpoint=endpoint, handler=handler, is_const=False, streaming=streaming, auth_required=auth_required ) return inner @@ -511,12 +484,7 @@ def inner(handler): self.openapi.add_openapi_path_obj("patch", endpoint, openapi_name, openapi_tags, handler) return self.add_route( - route_type=HttpMethod.PATCH, - endpoint=endpoint, - handler=handler, - is_const=False, - streaming=streaming, - auth_required=auth_required + route_type=HttpMethod.PATCH, endpoint=endpoint, handler=handler, is_const=False, streaming=streaming, auth_required=auth_required ) return inner @@ -543,12 +511,7 @@ def inner(handler): self.openapi.add_openapi_path_obj("head", endpoint, openapi_name, openapi_tags, handler) return self.add_route( - route_type=HttpMethod.HEAD, - endpoint=endpoint, - handler=handler, - is_const=False, - streaming=streaming, - auth_required=auth_required + route_type=HttpMethod.HEAD, endpoint=endpoint, handler=handler, is_const=False, streaming=streaming, auth_required=auth_required ) return inner @@ -575,12 +538,7 @@ def inner(handler): self.openapi.add_openapi_path_obj("options", endpoint, openapi_name, openapi_tags, handler) return self.add_route( - route_type=HttpMethod.OPTIONS, - endpoint=endpoint, - handler=handler, - is_const=False, - streaming=streaming, - auth_required=auth_required + route_type=HttpMethod.OPTIONS, endpoint=endpoint, handler=handler, is_const=False, streaming=streaming, auth_required=auth_required ) return inner @@ -606,12 +564,7 @@ def connect( def inner(handler): self.openapi.add_openapi_path_obj("connect", endpoint, openapi_name, openapi_tags, handler) return self.add_route( - route_type=HttpMethod.CONNECT, - endpoint=endpoint, - handler=handler, - is_const=False, - streaming=streaming, - auth_required=auth_required + route_type=HttpMethod.CONNECT, endpoint=endpoint, handler=handler, is_const=False, streaming=streaming, auth_required=auth_required ) return inner @@ -638,12 +591,7 @@ def inner(handler): self.openapi.add_openapi_path_obj("trace", endpoint, openapi_name, openapi_tags, handler) return self.add_route( - route_type=HttpMethod.TRACE, - endpoint=endpoint, - handler=handler, - is_const=False, - streaming=streaming, - auth_required=auth_required + route_type=HttpMethod.TRACE, endpoint=endpoint, handler=handler, is_const=False, streaming=streaming, auth_required=auth_required ) return inner @@ -686,29 +634,58 @@ def __init__(self, file_object: str, prefix: str = "", config: Config = Config() def __add_prefix(self, endpoint: str): return f"{self.prefix}{endpoint}" - def get(self, endpoint: str, const: bool = False, streaming: bool = False, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["get"]): - return super().get(endpoint=self.__add_prefix(endpoint), const=const, streaming=streaming, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + def get( + self, + endpoint: str, + const: bool = False, + streaming: bool = False, + auth_required: bool = False, + openapi_name: str = "", + openapi_tags: List[str] = ["get"], + ): + return super().get( + endpoint=self.__add_prefix(endpoint), + const=const, + streaming=streaming, + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=openapi_tags, + ) def post(self, endpoint: str, streaming: bool = False, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["post"]): - return super().post(endpoint=self.__add_prefix(endpoint), streaming=streaming, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + return super().post( + endpoint=self.__add_prefix(endpoint), streaming=streaming, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags + ) def put(self, endpoint: str, streaming: bool = False, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["put"]): - return super().put(endpoint=self.__add_prefix(endpoint), streaming=streaming, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + return super().put( + endpoint=self.__add_prefix(endpoint), streaming=streaming, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags + ) def delete(self, endpoint: str, streaming: bool = False, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["delete"]): - return super().delete(endpoint=self.__add_prefix(endpoint), streaming=streaming, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + return super().delete( + endpoint=self.__add_prefix(endpoint), streaming=streaming, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags + ) def patch(self, endpoint: str, streaming: bool = False, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["patch"]): - return super().patch(endpoint=self.__add_prefix(endpoint), streaming=streaming, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + return super().patch( + endpoint=self.__add_prefix(endpoint), streaming=streaming, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags + ) def head(self, endpoint: str, streaming: bool = False, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["head"]): - return super().head(endpoint=self.__add_prefix(endpoint), streaming=streaming, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + return super().head( + endpoint=self.__add_prefix(endpoint), streaming=streaming, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags + ) def trace(self, endpoint: str, streaming: bool = False, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["trace"]): - return super().trace(endpoint=self.__add_prefix(endpoint), streaming=streaming, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + return super().trace( + endpoint=self.__add_prefix(endpoint), streaming=streaming, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags + ) def options(self, endpoint: str, streaming: bool = False, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["options"]): - return super().options(endpoint=self.__add_prefix(endpoint), streaming=streaming, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + return super().options( + endpoint=self.__add_prefix(endpoint), streaming=streaming, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags + ) def ALLOW_CORS(app: Robyn, origins: Union[List[str], str]): diff --git a/robyn/responses.py b/robyn/responses.py index 6f2e454d5..41b8e4f0d 100644 --- a/robyn/responses.py +++ b/robyn/responses.py @@ -1,6 +1,6 @@ import mimetypes import os -from typing import Optional, Any, Union, Callable, Iterator, AsyncIterator, Dict +from typing import Optional, Union, Iterator, AsyncIterator, Dict from robyn.robyn import Headers, Response diff --git a/robyn/robyn.pyi b/robyn/robyn.pyi index fa8da141d..dcb035fc9 100644 --- a/robyn/robyn.pyi +++ b/robyn/robyn.pyi @@ -287,7 +287,7 @@ class Response: status_code (int): The status code of the response. e.g. 200, 404, 500 etc. response_type (Optional[str]): The response type of the response. e.g. text, json, html, file etc. headers (Union[Headers, dict]): The headers of the response or Headers directly. e.g. {"Content-Type": "application/json"} - description (Union[str, bytes, Iterator[bytes], AsyncIterator[bytes], Generator[bytes, None, None], AsyncGenerator[bytes, None]]): + description (Union[str, bytes, Iterator[bytes], AsyncIterator[bytes], Generator[bytes, None, None], AsyncGenerator[bytes, None]]): The body of the response. Can be: - str: Plain text response - bytes: Binary response diff --git a/robyn/router.py b/robyn/router.py index 94fbbc7f6..87c529d8d 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -317,7 +317,6 @@ def add_auth_middleware(self, endpoint: str): """ injected_dependencies: dict = {} - def decorator(handler): @wraps(handler) @@ -331,7 +330,7 @@ def inner_handler(request: Request, *args): status_code=401, headers=Headers({"WWW-Authenticate": self.authentication_handler.token_getter.scheme}), description=b"Unauthorized", # Use bytes to ensure proper type conversion - streaming=False + streaming=False, ) request.identity = identity return request