Skip to content

Commit

Permalink
Fix static type shape bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 authored and michaelosthege committed Oct 11, 2023
1 parent 081a0b4 commit 6834740
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pytensor/tensor/random/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def extract_batch_shape(p, ps, n):
return shape

batch_shape = [
s if b is False else constant(1, "int64")
s if not b else constant(1, "int64")
for s, b in zip(shape[:-n], p.type.broadcastable[:-n])
]
return batch_shape
Expand Down
7 changes: 6 additions & 1 deletion pytensor/tensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,13 @@ def __init__(
def parse_bcast_and_shape(s):
if isinstance(s, (bool, np.bool_)):
return 1 if s else None
else:
elif isinstance(s, (int, np.int_)):
return int(s)
elif s is None:
return s
raise ValueError(
f"TensorType broadcastable/shape must be a boolean, integer or None, got {type(s)} {s}"
)

self.shape = tuple(parse_bcast_and_shape(s) for s in shape)
self.dtype_specs() # error checking is done there
Expand Down
10 changes: 10 additions & 0 deletions tests/tensor/random/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pytensor.graph.op import get_test_value
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor import ones, stack
from pytensor.tensor.random.basic import (
_gamma,
bernoulli,
Expand Down Expand Up @@ -1465,3 +1466,12 @@ def test_rebuild():
assert y_new.type.shape == (100,)
assert y_new.shape.eval({x_new: x_new_test}) == (100,)
assert y_new.eval({x_new: x_new_test}).shape == (100,)


def test_categorical_join_p_static_shape():
"""Regression test against a bug caused by misreading a numpy.bool_"""
p = ones(3) / 3
prob = stack([p, 1 - p], axis=-1)
assert prob.type.shape == (3, 2)
x = categorical(p=prob)
assert x.type.shape == (3,)
15 changes: 11 additions & 4 deletions tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2046,17 +2046,24 @@ def test_mixed_ndim_error(self):
def test_static_shape_inference(self):
a = at.tensor(dtype="int8", shape=(2, 3))
b = at.tensor(dtype="int8", shape=(2, 5))
assert at.join(1, a, b).type.shape == (2, 8)
assert at.join(-1, a, b).type.shape == (2, 8)

res = at.join(1, a, b).type.shape
assert res == (2, 8)
assert all(isinstance(s, int) for s in res)

res = at.join(-1, a, b).type.shape
assert res == (2, 8)
assert all(isinstance(s, int) for s in res)

# Check early informative errors from static shape info
with pytest.raises(ValueError, match="must match exactly"):
at.join(0, at.ones((2, 3)), at.ones((2, 5)))

# Check partial inference
d = at.tensor(dtype="int8", shape=(2, None))
assert at.join(1, a, b, d).type.shape == (2, None)
return
res = at.join(1, a, b, d).type.shape
assert res == (2, None)
assert isinstance(res[0], int)

def test_split_0elem(self):
rng = np.random.default_rng(seed=utt.fetch_seed())
Expand Down
21 changes: 21 additions & 0 deletions tests/tensor/test_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,27 @@ def test_fixed_shape_basic():
assert t2.shape == (2, 4)


def test_shape_type_conversion():
t1 = TensorType("float64", shape=np.array([3], dtype=int))
assert t1.shape == (3,)
assert isinstance(t1.shape[0], int)
assert t1.broadcastable == (False,)
assert isinstance(t1.broadcastable[0], bool)

t2 = TensorType("float64", broadcastable=np.array([True, False], dtype="bool"))
assert t2.shape == (1, None)
assert isinstance(t2.shape[0], int)
assert t2.broadcastable == (True, False)
assert isinstance(t2.broadcastable[0], bool)
assert isinstance(t2.broadcastable[1], bool)

with pytest.raises(
ValueError,
match="TensorType broadcastable/shape must be a boolean, integer or None",
):
TensorType("float64", shape=("1", "2"))


def test_fixed_shape_clone():
t1 = TensorType("float64", (1,))

Expand Down

0 comments on commit 6834740

Please sign in to comment.