Skip to content

Commit

Permalink
add more type annotations through the code (#4401)
Browse files Browse the repository at this point in the history
* add more type annotations through the code

* add typing in reflex/utils

* misc typing

* more typings

* state typing

* keep typing

* typing init and utils

* more typing for components

* fix attempt for 3.9

* need more __future

* more typings

* type event plz

* type model

* type vars/base.py

* enable 'ANN001' for reflex folder (ignore tests and benchmarks)

* fix pyi

* add missing annotations

* use more precise error when ignoring
  • Loading branch information
Lendemor authored Jan 29, 2025
1 parent 2a92221 commit b8b3f89
Show file tree
Hide file tree
Showing 54 changed files with 286 additions and 214 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,14 @@ reportIncompatibleMethodOverride = false
target-version = "py310"
output-format = "concise"
lint.isort.split-on-trailing-comma = false
lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "N", "PERF", "PGH", "PTH", "RUF", "SIM", "T", "TRY", "W"]
lint.select = ["ANN001","B", "C4", "D", "E", "ERA", "F", "FURB", "I", "N", "PERF", "PGH", "PTH", "RUF", "SIM", "T", "TRY", "W"]
lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012", "TRY0"]
lint.pydocstyle.convention = "google"

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]
"tests/*.py" = ["D100", "D103", "D104", "B018", "PERF", "T", "N"]
"benchmarks/*.py" = ["D100", "D103", "D104", "B018", "PERF", "T", "N"]
"tests/*.py" = ["ANN001", "D100", "D103", "D104", "B018", "PERF", "T", "N"]
"benchmarks/*.py" = ["ANN001", "D100", "D103", "D104", "B018", "PERF", "T", "N"]
"reflex/.templates/*.py" = ["D100", "D103", "D104"]
"*.pyi" = ["D301", "D415", "D417", "D418", "E742", "N", "PGH"]
"pyi_generator.py" = ["N802"]
Expand Down
5 changes: 4 additions & 1 deletion reflex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@

from __future__ import annotations

from types import ModuleType
from typing import Any

from reflex.utils import (
compat, # for side-effects
lazy_loader,
Expand Down Expand Up @@ -365,5 +368,5 @@
)


def __getattr__(name):
def __getattr__(name: ModuleType | Any):
return getattr(name)
37 changes: 18 additions & 19 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,15 +408,15 @@ def _setup_state(self) -> None:
if self.api:

class HeaderMiddleware:
def __init__(self, app):
def __init__(self, app: ASGIApp):
self.app = app

async def __call__(
self, scope: MutableMapping[str, Any], receive, send
self, scope: MutableMapping[str, Any], receive: Any, send: Callable
):
original_send = send

async def modified_send(message):
async def modified_send(message: dict):
if message["type"] == "websocket.accept":
if scope.get("subprotocols"):
# The following *does* say "subprotocol" instead of "subprotocols", intentionally.
Expand Down Expand Up @@ -712,8 +712,8 @@ def add_custom_404_page(
Args:
component: The component to display at the page.
title: The title of the page.
description: The description of the page.
image: The image to display on the page.
description: The description of the page.
on_load: The event handler(s) that will be called each time the page load.
meta: The metadata of the page.
"""
Expand Down Expand Up @@ -1056,7 +1056,7 @@ def get_compilation_time() -> str:
with executor:
result_futures = []

def _submit_work(fn, *args, **kwargs):
def _submit_work(fn: Callable, *args, **kwargs):
f = executor.submit(fn, *args, **kwargs)
result_futures.append(f)

Expand Down Expand Up @@ -1387,15 +1387,14 @@ async def process(
if app._process_background(state, event) is not None:
# `final=True` allows the frontend send more events immediately.
yield StateUpdate(final=True)
return

# Process the event synchronously.
async for update in state._process(event):
# Postprocess the event.
update = await app._postprocess(state, event, update)

# Yield the update.
yield update
else:
# Process the event synchronously.
async for update in state._process(event):
# Postprocess the event.
update = await app._postprocess(state, event, update)

# Yield the update.
yield update
except Exception as ex:
telemetry.send_error(ex, context="backend")

Expand Down Expand Up @@ -1590,20 +1589,20 @@ def __init__(self, namespace: str, app: App):
self.sid_to_token = {}
self.app = app

def on_connect(self, sid, environ):
def on_connect(self, sid: str, environ: dict):
"""Event for when the websocket is connected.
Args:
sid: The Socket.IO session id.
environ: The request information, including HTTP headers.
"""
subprotocol = environ.get("HTTP_SEC_WEBSOCKET_PROTOCOL", None)
subprotocol = environ.get("HTTP_SEC_WEBSOCKET_PROTOCOL")
if subprotocol and subprotocol != constants.Reflex.VERSION:
console.warn(
f"Frontend version {subprotocol} for session {sid} does not match the backend version {constants.Reflex.VERSION}."
)

def on_disconnect(self, sid):
def on_disconnect(self, sid: str):
"""Event for when the websocket disconnects.
Args:
Expand All @@ -1625,7 +1624,7 @@ async def emit_update(self, update: StateUpdate, sid: str) -> None:
self.emit(str(constants.SocketEvent.EVENT), update, to=sid)
)

async def on_event(self, sid, data):
async def on_event(self, sid: str, data: Any):
"""Event for receiving front-end websocket events.
Raises:
Expand Down Expand Up @@ -1692,7 +1691,7 @@ async def on_event(self, sid, data):
# Emit the update from processing the event.
await self.emit_update(update=update, sid=sid)

async def on_ping(self, sid):
async def on_ping(self, sid: str):
"""Event for testing the API endpoint.
Args:
Expand Down
2 changes: 1 addition & 1 deletion reflex/app_mixins/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
Args:
task: The task to register.
task_kwargs: The kwargs of the task.
**task_kwargs: The kwargs of the task.
Raises:
InvalidLifespanTaskTypeError: If the task is a generator function.
Expand Down
2 changes: 1 addition & 1 deletion reflex/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def json(self) -> str:
default=serialize,
)

def set(self, **kwargs):
def set(self, **kwargs: Any):
"""Set multiple fields and return the object.
Args:
Expand Down
2 changes: 1 addition & 1 deletion reflex/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def empty_dir(path: str | Path, keep_files: list[str] | None = None):
path_ops.rm(element)


def is_valid_url(url) -> bool:
def is_valid_url(url: str) -> bool:
"""Check if a url is valid.
Args:
Expand Down
10 changes: 5 additions & 5 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def __init__(self, *args, **kwargs):
else:
continue

def determine_key(value):
def determine_key(value: Any):
# Try to create a var from the value
key = value if isinstance(value, Var) else LiteralVar.create(value)

Expand Down Expand Up @@ -707,7 +707,7 @@ def create(cls, *children, **props) -> Component:
# Filter out None props
props = {key: value for key, value in props.items() if value is not None}

def validate_children(children):
def validate_children(children: tuple | list):
for child in children:
if isinstance(child, (tuple, list)):
validate_children(child)
Expand Down Expand Up @@ -851,7 +851,7 @@ def _get_style(self) -> dict:
else {}
)

def render(self) -> Dict:
def render(self) -> dict:
"""Render the component.
Returns:
Expand All @@ -869,7 +869,7 @@ def render(self) -> Dict:
self._replace_prop_names(rendered_dict)
return rendered_dict

def _replace_prop_names(self, rendered_dict) -> None:
def _replace_prop_names(self, rendered_dict: dict) -> None:
"""Replace the prop names in the render dictionary.
Args:
Expand Down Expand Up @@ -909,7 +909,7 @@ def _validate_component_children(self, children: List[Component]):
comp.__name__ for comp in (Fragment, Foreach, Cond, Match)
]

def validate_child(child):
def validate_child(child: Any):
child_name = type(child).__name__

# Iterate through the immediate children of fragment
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/core/client_side_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def render(self) -> str:
return ""


def wait_for_client_redirect(component) -> Component:
def wait_for_client_redirect(component: Component) -> Component:
"""Wait for a redirect to occur before rendering a component.
This prevents the 404 page from flashing while the redirect is happening.
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/core/client_side_routing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class ClientSideRouting(Component):
"""
...

def wait_for_client_redirect(component) -> Component: ...
def wait_for_client_redirect(component: Component) -> Component: ...

class Default404Page(Component):
@overload
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/core/cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var:
if c2 is None:
raise ValueError("For conditional vars, the second argument must be set.")

def create_var(cond_part):
def create_var(cond_part: Any) -> Var[Any]:
return LiteralVar.create(cond_part)

# convert the truth and false cond parts into vars so the _var_data can be obtained.
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/core/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _process_cases(
return cases, default

@classmethod
def _create_case_var_with_var_data(cls, case_element):
def _create_case_var_with_var_data(cls, case_element: Any) -> Var:
"""Convert a case element into a Var.If the case
is a Style type, we extract the var data and merge it with the
newly created Var.
Expand Down
6 changes: 2 additions & 4 deletions reflex/components/datadisplay/logo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@ def svg_logo(color: Union[str, rx.Var[str]] = rx.color_mode_cond("#110F1F", "whi
The Reflex logo SVG.
"""

def logo_path(d):
return rx.el.svg.path(
d=d,
)
def logo_path(d: str):
return rx.el.svg.path(d=d)

paths = [
"M0 11.5999V0.399902H8.96V4.8799H6.72V2.6399H2.24V4.8799H6.72V7.1199H2.24V11.5999H0ZM6.72 11.5999V7.1199H8.96V11.5999H6.72Z",
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/el/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class Element(Component):
"""The base class for all raw HTML elements."""

def __eq__(self, other):
def __eq__(self, other: object):
"""Two elements are equal if they have the same tag.
Args:
Expand Down
8 changes: 5 additions & 3 deletions reflex/components/markdown/markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from hashlib import md5
from typing import Any, Callable, Dict, Sequence, Union

from reflex.components.component import Component, CustomComponent
from reflex.components.component import BaseComponent, Component, CustomComponent
from reflex.components.tags.tag import Tag
from reflex.utils import types
from reflex.utils.imports import ImportDict, ImportVar
Expand Down Expand Up @@ -379,7 +379,9 @@ def _get_map_fn_var_from_children(self, component: Component, tag: str) -> Var:
# fallback to the default fn Var creation if the component is not a MarkdownComponentMap.
return MarkdownComponentMap.create_map_fn_var(fn_body=formatted_component)

def _get_map_fn_custom_code_from_children(self, component) -> list[str]:
def _get_map_fn_custom_code_from_children(
self, component: BaseComponent
) -> list[str]:
"""Recursively get markdown custom code from children components.
Args:
Expand Down Expand Up @@ -409,7 +411,7 @@ def _get_map_fn_custom_code_from_children(self, component) -> list[str]:
return custom_code_list

@staticmethod
def _component_map_hash(component_map) -> str:
def _component_map_hash(component_map: dict) -> str:
inp = str(
{tag: component(_MOCK_ARG) for tag, component in component_map.items()}
).encode()
Expand Down
4 changes: 3 additions & 1 deletion reflex/components/next/image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Image component from next/image."""

from __future__ import annotations

from typing import Any, Literal, Optional, Union

from reflex.event import EventHandler, no_args_event_spec
Expand Down Expand Up @@ -93,7 +95,7 @@ def create(

style = props.get("style", {})

def check_prop_type(prop_name, prop_value):
def check_prop_type(prop_name: str, prop_value: int | str | None):
if types.check_prop_in_allowed_types(prop_value, allowed_types=[int]):
props[prop_name] = prop_value

Expand Down
2 changes: 1 addition & 1 deletion reflex/components/props.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def dict(self, *args, **kwargs):
class NoExtrasAllowedProps(Base):
"""A class that holds props to be passed or applied to a component with no extra props allowed."""

def __init__(self, component_name=None, **kwargs):
def __init__(self, component_name: str | None = None, **kwargs):
"""Initialize the props.
Args:
Expand Down
12 changes: 7 additions & 5 deletions reflex/components/radix/themes/color_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from __future__ import annotations

from typing import Dict, List, Literal, Optional, Union, get_args
from typing import Any, Dict, List, Literal, Optional, Union, get_args

from reflex.components.component import BaseComponent
from reflex.components.core.cond import Cond, color_mode_cond, cond
Expand Down Expand Up @@ -78,17 +78,19 @@ def create(


# needed to inverse contains for find
def _find(const: List[str], var):
def _find(const: List[str], var: Any):
return LiteralArrayVar.create(const).contains(var)


def _set_var_default(props, position, prop, default1, default2=""):
def _set_var_default(
props: dict, position: Any, prop: str, default1: str, default2: str = ""
):
props.setdefault(
prop, cond(_find(position_map[prop], position), default1, default2)
)


def _set_static_default(props, position, prop, default):
def _set_static_default(props: dict, position: Any, prop: str, default: str):
if prop in position:
props.setdefault(prop, default)

Expand Down Expand Up @@ -142,7 +144,7 @@ def create(

if allow_system:

def color_mode_item(_color_mode):
def color_mode_item(_color_mode: str):
return dropdown_menu.item(
_color_mode.title(), on_click=set_color_mode(_color_mode)
)
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/radix/themes/layout/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ class List(ComponentNamespace):
unordered_list = list_ns.unordered


def __getattr__(name):
def __getattr__(name: Any):
# special case for when accessing list to avoid shadowing
# python's built in list object.
if name == "list":
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/recharts/charts.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _ensure_valid_dimension(name: str, value: Any) -> None:
)

@classmethod
def create(cls, *children, **props) -> Component:
def create(cls, *children: Any, **props: Any) -> Component:
"""Create a chart component.
Args:
Expand Down
Loading

0 comments on commit b8b3f89

Please sign in to comment.