Skip to content

Commit

Permalink
solve some but not all pyright issues
Browse files Browse the repository at this point in the history
adhami3310 committed Jan 17, 2025
1 parent 3d73f56 commit 112b2ed
Showing 3 changed files with 53 additions and 24 deletions.
42 changes: 30 additions & 12 deletions reflex/utils/pyi_generator.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
from multiprocessing import Pool, cpu_count
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin
from typing import Any, Callable, Iterable, Sequence, Type, cast, get_args, get_origin

from reflex.components.component import Component
from reflex.utils import types as rx_types
@@ -229,7 +229,9 @@ def _generate_imports(
"""
return [
*[
ast.ImportFrom(module=name, names=[ast.alias(name=val) for val in values])
ast.ImportFrom(
module=name, names=[ast.alias(name=val) for val in values], level=0
)
for name, values in DEFAULT_IMPORTS.items()
],
ast.Import([ast.alias("reflex")]),
@@ -428,16 +430,15 @@ def type_to_ast(typ, cls: type) -> ast.AST:
return ast.Name(id=base_name)

# Convert all type arguments recursively
arg_nodes = [type_to_ast(arg, cls) for arg in args]
arg_nodes = cast(list[ast.expr], [type_to_ast(arg, cls) for arg in args])

# Special case for single-argument types (like List[T] or Optional[T])
if len(arg_nodes) == 1:
slice_value = arg_nodes[0]
else:
slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load())

return ast.Subscript(
value=ast.Name(id=base_name), slice=ast.Index(value=slice_value), ctx=ast.Load()
value=ast.Name(id=base_name), slice=slice_value, ctx=ast.Load()
)


@@ -630,7 +631,7 @@ def figure_out_return_type(annotation: Any):
),
),
ast.Expr(
value=ast.Ellipsis(),
value=ast.Constant(...),
),
],
decorator_list=[
@@ -641,8 +642,14 @@ def figure_out_return_type(annotation: Any):
else [ast.Name(id="classmethod")]
),
],
lineno=node.lineno if node is not None else None,
returns=ast.Constant(value=clz.__name__),
**(
{
"lineno": node.lineno,
}
if node is not None
else {}
),
)
return definition

@@ -690,13 +697,19 @@ def _generate_staticmethod_call_functiondef(
),
],
decorator_list=[ast.Name(id="staticmethod")],
lineno=node.lineno if node is not None else None,
returns=ast.Constant(
value=_get_type_hint(
typing.get_type_hints(clz.__call__).get("return", None),
type_hint_globals,
)
),
**(
{
"lineno": node.lineno,
}
if node is not None
else {}
),
)
return definition

@@ -731,7 +744,12 @@ def _generate_namespace_call_functiondef(
# Determine which class is wrapped by the namespace __call__ method
component_clz = clz.__call__.__self__

if clz.__call__.__func__.__name__ != "create":
func = getattr(clz.__call__, "__func__", None)

if func is None:
raise TypeError(f"__call__ method on {clz_name} does not have a __func__")

if func.__name__ != "create":
return None

definition = _generate_component_create_functiondef(
@@ -914,7 +932,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
node.body.append(call_definition)
if not node.body:
# We should never return an empty body.
node.body.append(ast.Expr(value=ast.Ellipsis()))
node.body.append(ast.Expr(value=ast.Constant(...)))
self.current_class = None
return node

@@ -941,9 +959,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
if node.name.startswith("_") and node.name != "__call__":
return None # remove private methods

if node.body[-1] != ast.Expr(value=ast.Ellipsis()):
if node.body[-1] != ast.Expr(value=ast.Constant(...)):
# Blank out the function body for public functions.
node.body = [ast.Expr(value=ast.Ellipsis())]
node.body = [ast.Expr(value=ast.Constant(...))]
return node

def visit_Assign(self, node: ast.Assign) -> ast.Assign | None:
21 changes: 14 additions & 7 deletions reflex/utils/types.py
Original file line number Diff line number Diff line change
@@ -69,21 +69,21 @@ def override(func: Callable) -> Callable:


# Potential GenericAlias types for isinstance checks.
GenericAliasTypes = [_GenericAlias]
_GenericAliasTypes: list[type] = [_GenericAlias]

with contextlib.suppress(ImportError):
# For newer versions of Python.
from types import GenericAlias # type: ignore

GenericAliasTypes.append(GenericAlias)
_GenericAliasTypes.append(GenericAlias)

with contextlib.suppress(ImportError):
# For older versions of Python.
from typing import _SpecialGenericAlias # type: ignore

GenericAliasTypes.append(_SpecialGenericAlias)
_GenericAliasTypes.append(_SpecialGenericAlias)

GenericAliasTypes = tuple(GenericAliasTypes)
GenericAliasTypes = tuple(_GenericAliasTypes)

# Potential Union types for isinstance checks (UnionType added in py3.10).
UnionTypes = (Union, types.UnionType) if hasattr(types, "UnionType") else (Union,)
@@ -181,7 +181,7 @@ def is_generic_alias(cls: GenericType) -> bool:
return isinstance(cls, GenericAliasTypes)


def unionize(*args: GenericType) -> Type:
def unionize(*args: GenericType) -> GenericType:
"""Unionize the types.
Args:
@@ -415,7 +415,7 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None


@lru_cache()
def get_base_class(cls: GenericType) -> Type:
def get_base_class(cls: GenericType) -> Type | tuple[Type, ...]:
"""Get the base class of a class.
Args:
@@ -435,7 +435,14 @@ def get_base_class(cls: GenericType) -> Type:
return type(get_args(cls)[0])

if is_union(cls):
return tuple(get_base_class(arg) for arg in get_args(cls))
base_classes = []
for arg in get_args(cls):
sub_base_classes = get_base_class(arg)
if isinstance(sub_base_classes, tuple):
base_classes.extend(sub_base_classes)
else:
base_classes.append(sub_base_classes)
return tuple(base_classes)

return get_base_class(cls.__origin__) if is_generic_alias(cls) else cls

14 changes: 9 additions & 5 deletions reflex/vars/number.py
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
Sequence,
TypeVar,
Union,
cast,
overload,
)

@@ -1102,7 +1103,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
_cases: tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...] = dataclasses.field(
default_factory=tuple
)
_default: Var[VAR_TYPE] = dataclasses.field(
_default: Var[VAR_TYPE] = dataclasses.field( # pyright: ignore[reportAssignmentType]
default_factory=lambda: Var.create(None)
)

@@ -1170,19 +1171,22 @@ def create(
The match operation.
"""
cond = Var.create(cond)
cases = tuple(tuple(Var.create(c) for c in case) for case in cases)
default = Var.create(default)
cases = cast(
tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...],
tuple(tuple(Var.create(c) for c in case) for case in cases),
)
_default = cast(Var[VAR_TYPE], Var.create(default))
var_type = _var_type or unionize(
*(case[-1]._var_type for case in cases),
default._var_type,
_default._var_type,
)
return cls(
_js_expr="",
_var_data=_var_data,
_var_type=var_type,
_cond=cond,
_cases=cases,
_default=default,
_default=_default,
)


0 comments on commit 112b2ed

Please sign in to comment.