Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tunable: Allow empty default lists when type-hinted #212

Merged
merged 5 commits into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions magicbot/magic_tunable.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def execute(self):
"_ntsubtable",
"_ntwritedefault",
# "__doc__",
"__orig_class__",
"_topic_type",
"_nt",
)
Expand All @@ -100,13 +101,48 @@ def __init__(
self._ntwritedefault = writeDefault
# self.__doc__ = doc

self._topic_type = _get_topic_type_for_value(self._ntdefault)
if self._topic_type is None:
checked_type: type = type(self._ntdefault)
# Defer checks for empty sequences to check type hints.
# Report errors here when we can so the error points to the tunable line.
if default or not isinstance(default, collections.abc.Sequence):
topic_type = _get_topic_type_for_value(default)
if topic_type is None:
checked_type: type = type(default)
raise TypeError(
f"tunable is not publishable to NetworkTables, type: {checked_type.__name__}"
)
self._topic_type = topic_type

def __set_name__(self, owner: type, name: str) -> None:
type_hint: Optional[type] = None
# __orig_class__ is set after __init__, check it here.
orig_class = getattr(self, "__orig_class__", None)
if orig_class is not None:
# Accept field = tunable[Sequence[int]]([])
type_hint = typing.get_args(orig_class)[0]
else:
type_hint = typing.get_type_hints(owner).get(name)
origin = typing.get_origin(type_hint)
if origin is typing.ClassVar:
# Accept field: ClassVar[tunable[Sequence[int]]] = tunable([])
type_hint = typing.get_args(type_hint)[0]
origin = typing.get_origin(type_hint)
if origin is tunable:
# Accept field: tunable[Sequence[int]] = tunable([])
type_hint = typing.get_args(type_hint)[0]

if type_hint is not None:
topic_type = _get_topic_type(type_hint)
else:
topic_type = _get_topic_type_for_value(self._ntdefault)

if topic_type is None:
checked_type: type = type_hint or type(self._ntdefault)
raise TypeError(
f"tunable is not publishable to NetworkTables, type: {checked_type.__name__}"
)

self._topic_type = topic_type

@overload
def __get__(self, instance: None, owner=None) -> "tunable[V]": ...

Expand Down Expand Up @@ -218,7 +254,7 @@ class MyComponent:
navx: ...

@feedback
def get_angle(self):
def get_angle(self) -> float:
return self.navx.getYaw()

class MyRobot(magicbot.MagicRobot):
Expand Down Expand Up @@ -297,6 +333,8 @@ def _get_topic_type(
if hasattr(inner_type, "WPIStruct"):
return lambda topic: ntcore.StructArrayTopic(topic, inner_type)

return None


def collect_feedbacks(component, cname: str, prefix: Optional[str] = "components"):
"""
Expand Down
44 changes: 43 additions & 1 deletion tests/test_magicbot_tunable.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import ClassVar, List, Sequence

import ntcore
import pytest
from wpimath import geometry
Expand Down Expand Up @@ -25,6 +27,7 @@ class Component:
topic = nt.getTopic(name)
assert topic.getTypeString() == type_str
assert topic.genericSubscribe().get().value() == value
assert getattr(component, name) == value

for name, value in [
("rotation", geometry.Rotation2d()),
Expand All @@ -33,13 +36,15 @@ class Component:
assert nt.getTopic(name).getTypeString() == f"struct:{struct_type.__name__}"
topic = nt.getStructTopic(name, struct_type)
assert topic.subscribe(None).get() == value
assert getattr(component, name) == value

for name, struct_type, value in [
("rotations", geometry.Rotation2d, [geometry.Rotation2d()]),
]:
assert nt.getTopic(name).getTypeString() == f"struct:{struct_type.__name__}[]"
topic = nt.getStructArrayTopic(name, struct_type)
assert topic.subscribe([]).get() == value
assert getattr(component, name) == value


def test_tunable_errors():
Expand All @@ -50,7 +55,44 @@ class Component:


def test_tunable_errors_with_empty_sequence():
with pytest.raises(ValueError):
with pytest.raises((RuntimeError, ValueError)):

class Component:
empty = tunable([])


def test_type_hinted_empty_sequences() -> None:
class Component:
generic_seq = tunable[Sequence[int]](())
class_var_seq: ClassVar[tunable[Sequence[int]]] = tunable(())
inst_seq: Sequence[int] = tunable(())

generic_typing_list = tunable[List[int]]([])
class_var_typing_list: ClassVar[tunable[List[int]]] = tunable([])
inst_typing_list: List[int] = tunable([])

# TODO(davo): re-enable after py3.8 is dropped
virtuald marked this conversation as resolved.
Show resolved Hide resolved
# generic_list = tunable[list[int]]([])
# class_var_list: ClassVar[tunable[list[int]]] = tunable([])
# inst_list: list[int] = tunable([])

component = Component()
setup_tunables(component, "test_type_hinted_sequences")
NetworkTables = ntcore.NetworkTableInstance.getDefault()
nt = NetworkTables.getTable("/components/test_type_hinted_sequences")

for name in [
"generic_seq",
"class_var_seq",
"inst_seq",
"generic_typing_list",
"class_var_typing_list",
"inst_typing_list",
# "generic_list",
# "class_var_list",
# "inst_list",
]:
assert nt.getTopic(name).getTypeString() == "int[]"
entry = nt.getEntry(name)
assert entry.getIntegerArray(None) == []
assert getattr(component, name) == []
Loading