diff --git a/amaranth/lib/data.py b/amaranth/lib/data.py index c33771f9c..28ac905a2 100644 --- a/amaranth/lib/data.py +++ b/amaranth/lib/data.py @@ -775,6 +775,9 @@ def eq(self, other): :class:`.Assign` :py:`self.as_value().eq(other)` """ + if isinstance(other, ValueCastable): + if not self.shape() == Layout.cast(other.shape()): + raise TypeError(f"Cannot assign value with shape {other.shape()} to view with layout {self.shape()}") return self.as_value().eq(other) def __getitem__(self, key): diff --git a/amaranth/lib/enum.py b/amaranth/lib/enum.py index 0beadacb7..043a0e5b7 100644 --- a/amaranth/lib/enum.py +++ b/amaranth/lib/enum.py @@ -248,6 +248,9 @@ def eq(self, other): :class:`Assign` ``self.as_value().eq(other)`` """ + if isinstance(other, ValueCastable): + if not self.shape() == other.shape(): + raise TypeError(f"Cannot assign value with shape {other.shape()} to value with shape {self.shape()}") return self.as_value().eq(other) def __add__(self, other): diff --git a/amaranth/lib/wiring.py b/amaranth/lib/wiring.py index ef0a7499d..67ddcdb7e 100644 --- a/amaranth/lib/wiring.py +++ b/amaranth/lib/wiring.py @@ -2,6 +2,7 @@ import enum import re import warnings +import inspect try: import annotationlib # py3.14+ except ImportError: @@ -1562,21 +1563,26 @@ def connect(m, *args, **kwargs): (out_path, out_member), = out_kind for (in_path, in_member) in in_kind: def connect_value(*, out_path, in_path, src_loc_at): - in_value = Value.cast(_traverse_path(in_path, objects)) - out_value = Value.cast(_traverse_path(out_path, objects)) - assert type(in_value) in (Const, Signal) + in_value = _traverse_path(in_path, objects) + out_value = _traverse_path(out_path, objects) # If the input is a constant, only a constant may be connected to it. Ensure that # this is the case. + try: + in_value = Const.cast(in_value) + except TypeError: + pass if type(in_value) is Const: # If the output is not a constant, the connection is illegal. - if type(out_value) is not Const: + try: + out_value = Const.cast(out_value) + except TypeError: raise ConnectionError( f"Cannot connect input member {_format_path(in_path)} that has " f"a constant value {in_value.value!r} to an output member " f"{_format_path(out_path)} that has a varying value") # If the output is a constant, the connection is legal only if the value is # the same for both the input and the output. - if type(out_value) is Const and in_value.value != out_value.value: + if in_value.value != out_value.value: raise ConnectionError( f"Cannot connect input member {_format_path(in_path)} that has " f"a constant value {in_value.value!r} to an output member " @@ -1586,8 +1592,24 @@ def connect_value(*, out_path, in_path, src_loc_at): # value (which is constant) is consistent with a connection that would have # been made. return - # A connection that is made at this point is guaranteed to be valid. - connections.append(in_value.eq(out_value, src_loc_at=src_loc_at + 1)) + # If the input is a ValueCastable, it must implement `eq()`. + try: + eq = in_value.eq + except AttributeError: + raise ConnectionError( + f"Cannot connect input member {_format_path(in_path)} because the input " + f"value {in_value!r} does not support assignment") + # The `eq()` method may take a `src_loc_at` argument; provide it if it does. + if 'src_loc_at' in inspect.signature(eq).parameters: + kwargs = {'src_loc_at': src_loc_at + 1} + else: + kwargs = {} + try: + connections.append(eq(out_value, **kwargs)) + except Exception as e: + raise ConnectionError( + f"Cannot connect input member {_format_path(in_path)} to output member " + f"{_format_path(out_path)} because assignment failed") from e def connect_dimensions(dimensions, *, out_path, in_path, src_loc_at): if not dimensions: return connect_value(out_path=out_path, in_path=in_path, src_loc_at=src_loc_at) diff --git a/tests/test_lib_wiring.py b/tests/test_lib_wiring.py index eb2639047..9fcce5dce 100644 --- a/tests/test_lib_wiring.py +++ b/tests/test_lib_wiring.py @@ -851,6 +851,24 @@ class Cycle(enum.Enum, shape=2): q=NS(signature=Signature({"a": In(Cycle)}), a=Signal(Cycle))) + def test_shape_mismatch_layout(self): + class LastDelimited(data.Struct): + data: 8 + last: 1 + class FirstDelimited(data.Struct): + data: 8 + first: 1 + + m = Module() + with self.assertRaisesRegex(ConnectionError, + r"^Cannot connect input member 'q\.a' to output member 'p\.a' because assignment " + r"failed$"): + connect(m, + p=NS(signature=Signature({"a": Out(LastDelimited)}), + a=Signal(LastDelimited)), + q=NS(signature=Signature({"a": In(FirstDelimited)}), + a=Signal(FirstDelimited))) + def test_init_mismatch(self): m = Module() with self.assertRaisesRegex(ConnectionError,