Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make Signal and SignalInstance Generic, support static type validation of signal connections #304

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
14 changes: 13 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,19 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- run: pipx run check-manifest
- run: |
pipx run check-manifest

check-templates:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.x"
- run: |
pip install ruff mypy
CHECK_STUBS=1 python scripts/build_stub.py

test:
name: Test
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
rev: v1.11.2
hooks:
- id: mypy
exclude: tests|_throttler.pyi
exclude: tests|_throttler.pyi|.*_signal.pyi
additional_dependencies:
- types-attrs
- pydantic
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ select = [
"RUF", # ruff-specific rules
]
ignore = [
"D401", # First line should be in imperative mood
"D401", # First line should be in imperative mood
]

[tool.ruff.lint.per-file-ignores]
Expand Down Expand Up @@ -182,6 +182,7 @@ disallow_subclassing_any = false
show_error_codes = true
pretty = true


[[tool.mypy.overrides]]
module = ["numpy.*", "wrapt", "pydantic.*"]
ignore_errors = true
Expand Down Expand Up @@ -215,6 +216,7 @@ omit = ["*/_pyinstaller_util/*"]
ignore = [
".ruff_cache/**/*",
".github_changelog_generator",
"scripts/*",
".pre-commit-config.yaml",
"tests/**/*",
"typesafety/*",
Expand Down
132 changes: 132 additions & 0 deletions scripts/build_stub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""Build _signal.pyi with def connect @overloads."""

import os
import re
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
from textwrap import indent

ROOT = Path(__file__).parent.parent / "src" / "psygnal"
TEMPLATE_PATH = ROOT / "_signal.py.jinja2"
DEST_PATH = TEMPLATE_PATH.with_suffix("")

# Maximum number of arguments allowed in callbacks
MAX_ARGS = 5


@dataclass
class Arg:
"""Single arg."""

name: str
hint: str
default: str | None = None


@dataclass
class Sig:
"""Full signature."""

arguments: list[Arg]
return_hint: str

def render(self) -> str:
"""Render the signature as a def connect overload."""
args = ", ".join(f"{arg.name}: {arg.hint}" for arg in self.arguments) + ","
args += """
*,
thread: threading.Thread | Literal["main", "current"] | None = None,
check_nargs: bool | None = None,
check_types: bool | None = None,
unique: bool | str = False,
max_args: int | None = None,
on_ref_error: RefErrorChoice = "warn",
priority: int = 0,
"""
return f"\n@overload\ndef connect({args}) -> {self.return_hint}: ..."


connect_overloads: list[Sig] = []
for nself in range(MAX_ARGS + 1):
for ncallback in range(nself + 1):
if nself:
self_types = ", ".join(f"type[_T{i+1}]" for i in range(nself))
else:
self_types = "()"
arg_types = ", ".join(f"_T{i+1}" for i in range(ncallback))
slot_type = f"Callable[[{arg_types}], RetT]"
connect_overloads.append(
Sig(
arguments=[
Arg(name="self", hint=f"SignalInstance[{self_types}]"),
Arg(name="slot", hint=slot_type),
],
return_hint=slot_type,
)
)

connect_overloads.append(
Sig(
arguments=[
Arg(name="self", hint="SignalInstance[Unparametrized]"),
Arg(name="slot", hint="F"),
],
return_hint="F",
)
)
connect_overloads.append(
Sig(
arguments=[
Arg(name="self", hint="SignalInstance"),
],
return_hint="Callable[[F], F]",
)
)


STUB = Path("src/psygnal/_signal.pyi")


if __name__ == "__main__":
existing_stub = STUB.read_text() if STUB.exists() else None

# make a temporary file to write to
with TemporaryDirectory() as tmpdir:
subprocess.run(
[ # noqa
"stubgen",
"--include-private",
# "--include-docstrings",
"src/psygnal/_signal.py",
"-o",
tmpdir,
]
)
stub_path = Path(tmpdir) / "psygnal" / "_signal.pyi"
new_stub = "from typing import NewType\n" + stub_path.read_text()
new_stub = new_stub.replace(
"ReemissionVal: Incomplete",
'ReemissionVal = Literal["immediate", "queued", "latest-only"]',
)
new_stub = new_stub.replace(
"Unparametrized: Incomplete",
'Unparametrized = NewType("Unparametrized", object)',
)
overloads = "\n".join(sig.render() for sig in connect_overloads)
overloads = indent(overloads, " ")
new_stub = re.sub(r"def connect.+\.\.\.", overloads, new_stub)

stub_path.write_text(new_stub)
subprocess.run(["ruff", "format", tmpdir]) # noqa
subprocess.run(["ruff", "check", tmpdir, "--fix"]) # noqa
new_stub = stub_path.read_text()

if os.getenv("CHECK_STUBS"):
if existing_stub != new_stub:
raise RuntimeError(f"{STUB} content not up to date.")
sys.exit(0)

STUB.write_text(new_stub)
74 changes: 74 additions & 0 deletions scripts/render_connect_overloads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Render @overload for SignalInstance.connect."""

import os
import subprocess
from dataclasses import dataclass
from pathlib import Path
from tempfile import NamedTemporaryFile

from jinja2 import Template

ROOT = Path(__file__).parent.parent / "src" / "psygnal"
TEMPLATE_PATH = ROOT / "_signal.py.jinja2"
DEST_PATH = TEMPLATE_PATH.with_suffix("")

# Maximum number of arguments allowed in callbacks
MAX_ARGS = 5


@dataclass
class Arg:
"""Single arg."""

name: str
hint: str
default: str | None = None


@dataclass
class Sig:
"""Full signature."""

arguments: list[Arg]
return_hint: str


connect_overloads: list[Sig] = []
for nself in range(MAX_ARGS + 1):
for ncallback in range(nself + 1):
if nself:
self_types = ", ".join(f"type[_T{i+1}]" for i in range(nself))
else:
self_types = "()"
arg_types = ", ".join(f"_T{i+1}" for i in range(ncallback))
slot_type = f"Callable[[{arg_types}], RetT]"
connect_overloads.append(
Sig(
arguments=[
Arg(name="self", hint=f"SignalInstance[{self_types}]"),
Arg(name="slot", hint=slot_type),
],
return_hint=slot_type,
)
)

template: Template = Template(TEMPLATE_PATH.read_text())
result = template.render(number_of_types=MAX_ARGS, connect_overloads=connect_overloads)

result = (
"# WARNING: do not modify this code, it is generated by "
f"{TEMPLATE_PATH.name}\n\n" + result
)

# make a temporary file to write to
with NamedTemporaryFile(suffix=".py") as tmp:
Path(tmp.name).write_text(result)
subprocess.run(["ruff", "format", tmp.name]) # noqa
subprocess.run(["ruff", "check", tmp.name, "--fix"]) # noqa
result = Path(tmp.name).read_text()

current_content = DEST_PATH.read_text() if DEST_PATH.exists() else ""
if current_content != result and os.getenv("CHECK_JINJA"):
raise RuntimeError(f"{DEST_PATH} content not up to date with {TEMPLATE_PATH.name}")

DEST_PATH.write_text(result)
22 changes: 15 additions & 7 deletions src/psygnal/_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,17 @@
Literal,
Mapping,
NamedTuple,
Type,
overload,
)

from psygnal._signal import _NULL, Signal, SignalInstance, _SignalBlocker
from psygnal._signal import (
_NULL,
Signal,
SignalInstance,
Unparametrized,
_SignalBlocker,
)

from ._mypyc import mypyc_attr

Expand All @@ -36,6 +43,7 @@
from psygnal._signal import F, ReducerFunc
from psygnal._weak_callback import RefErrorChoice, WeakCallback


__all__ = ["EmissionInfo", "SignalGroup"]


Expand All @@ -52,7 +60,7 @@ class EmissionInfo(NamedTuple):
args: tuple[Any, ...]


class SignalRelay(SignalInstance):
class SignalRelay(SignalInstance[Type[EmissionInfo]]):
"""Special SignalInstance that can be used to connect to all signals in a group.

This class will rarely be instantiated by a user (or anything other than a
Expand All @@ -69,7 +77,7 @@ class SignalRelay(SignalInstance):
def __init__(
self, signals: Mapping[str, SignalInstance], instance: Any = None
) -> None:
super().__init__(signature=(EmissionInfo,), instance=instance)
super().__init__((EmissionInfo,), instance=instance)
self._signals = MappingProxyType(signals)
self._sig_was_blocked: dict[str, bool] = {}

Expand Down Expand Up @@ -381,15 +389,15 @@ def __len__(self) -> int:
"""Return the number of signals in the group (not including the relay)."""
return len(self._psygnal_instances)

def __getitem__(self, item: str) -> SignalInstance:
def __getitem__(self, item: str) -> SignalInstance[Unparametrized]:
"""Get a signal instance by name."""
return self._psygnal_instances[item]

# this is just here for type checking, particularly on cases
# where the SignalGroup comes from the SignalGroupDescriptor
# (such as in evented dataclasses). In those cases, it's hard to indicate
# to mypy that all remaining attributes are SignalInstances.
def __getattr__(self, __name: str) -> SignalInstance:
def __getattr__(self, __name: str) -> SignalInstance[Unparametrized]:
"""Get a signal instance by name."""
raise AttributeError( # pragma: no cover
f"{type(self).__name__!r} object has no attribute {__name!r}"
Expand Down Expand Up @@ -466,7 +474,7 @@ def connect(

def connect(
self,
slot: F | None = None,
slot: Callable | None = None,
*,
thread: threading.Thread | Literal["main", "current"] | None = None,
check_nargs: bool | None = None,
Expand All @@ -475,7 +483,7 @@ def connect(
max_args: int | None = None,
on_ref_error: RefErrorChoice = "warn",
priority: int = 0,
) -> Callable[[F], F] | F:
) -> Callable:
if slot is None:
return self._psygnal_relay.connect(
thread=thread,
Expand Down
Loading
Loading