-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_typecheck.py
157 lines (111 loc) · 3.7 KB
/
test_typecheck.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from contextlib import nullcontext as does_not_raise
from functools import partial
import pytest
from awfutils import typecheck
# `typecheck` with show_src=True pre-bound, so it can be used directly as a
# decorator (see test_typecheck_jax).
# NOTE(review): show_src presumably prints the instrumented source of the
# decorated function — confirm against awfutils.
typecheck_show_src = partial(typecheck, show_src=True)
def test_typecheck_1():
    """typecheck verifies both argument annotations and annotated assignments.

    Each decorated function is first called through ``__wrapped__``. That path
    has no generated checks; only the hand-written asserts run. The function is
    then called through the checking wrapper.
    """

    @typecheck
    def well_typed(x: int, t: float) -> float:
        y: float = x * t
        assert isinstance(y, float), f"y : {type(y)} not of type float"
        z: int = x // 2
        assert isinstance(z, int), "z not of type int"
        return z * y

    # The hand-written asserts alone accept these inputs...
    well_typed.__wrapped__(3, 4.2)
    # ...and so do the generated checks.
    with does_not_raise():
        well_typed(3, 4.2)
    # An int passed for the float parameter `t` is rejected at call time.
    with pytest.raises(TypeError, match="t not of type float"):
        well_typed(3, 5)

    @typecheck
    def mistyped(x: int, t: int) -> float:
        y: float = x * t  # int * int is an int, so this annotation is violated
        z: int = x // 2
        return z * y

    # Undecorated, nothing checks the annotation on `y`.
    with does_not_raise():
        mistyped.__wrapped__(3, 5)
    # Decorated, the annotated assignment to `y` is checked and fails.
    with pytest.raises(TypeError, match="y not of type float"):
        mistyped(3, 5)
# Module-level name read by a checked function in test_typecheck_scope.
z_in_global_scope = 9


def test_typecheck_scope():
    """Checked functions can still read module globals and enclosing locals."""

    def exercise(fn):
        # Run the undecorated function, then the checked wrapper, with the
        # default value of `t`.
        fn.__wrapped__(3)
        with does_not_raise():
            fn(3)

    @typecheck
    def uses_global(x: int, t: float = 4.2) -> float:
        return x * t * z_in_global_scope

    exercise(uses_global)

    z_in_outer_scope = 8

    @typecheck
    def uses_closure(x: int, t: float = 4.2) -> float:
        return x * t * z_in_outer_scope

    exercise(uses_closure)
def test_typecheck_jax():
    """Argument checks against ``jnp.ndarray``, both eagerly and under ``jax.jit``.

    Skipped when jax is not installed.
    """
    # importorskip only treats ImportError as "not installed"; a broken jax
    # install now fails loudly instead of being silently skipped.
    jax = pytest.importorskip("jax")
    jnp = pytest.importorskip("jax.numpy")

    @typecheck_show_src
    def foo1(x: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray:
        y: jnp.ndarray = x * t
        z: jnp.ndarray = y / 2
        return z

    # Premise of the eager check below: a Python int is not a jnp.ndarray.
    assert not isinstance(3, jnp.ndarray)

    float_array = jnp.ones((3, 5))
    with pytest.raises(TypeError, match="x not of type jnp.ndarray"):
        foo1(3, float_array)

    # Jitted, it will not raise, as the tracers are of type jnp.ndarray
    with does_not_raise():
        jax.jit(foo1)(3, float_array)
def test_typecheck_jaxtyping1():
    """typecheck accepts jaxtyping annotations and composes with ``jax.jit``.

    Skipped when jax or jaxtyping is not installed.
    """
    jax = pytest.importorskip("jax")
    jaxtyping = pytest.importorskip("jaxtyping")
    from jaxtyping import f32

    # TODO: name the scalar-int annotation, e.g. int_t = jaxtyping.i[""]
    rng = jax.random.PRNGKey(42)
    vec_f32 = jax.random.uniform(rng, (11,))

    # Decorators apply bottom-up: typecheck wraps foo1 first, then jax.jit
    # wraps the checked function.
    # NOTE(review): `refers` hands the jaxtyping module to typecheck,
    # presumably so its generated checks can resolve `jaxtyping.i[...]`;
    # confirm against awfutils.
    @jax.jit
    @partial(typecheck, show_src=True, refers=(jaxtyping,))
    def foo1(x: jaxtyping.i[""], t: f32["N"]) -> f32["N"]:
        z: f32["N"] = x * t
        return z

    with does_not_raise():
        foo1(3, vec_f32)
def test_typecheck_jaxtyping2():
    """Plain ``jaxtyped`` ignores annotated assignments; ``typecheck`` enforces them.

    Skipped when jax or jaxtyping is not installed.
    """
    jax = pytest.importorskip("jax")
    jaxtyping = pytest.importorskip("jaxtyping")
    from jaxtyping import f32, jaxtyped

    rng = jax.random.PRNGKey(42)
    vec_f32 = jax.random.uniform(rng, (11,))

    # Raw jaxtyped - won't check the statement annotation
    @jaxtyped
    def standardize(x: jaxtyping.f32["N"], eps=1e-5) -> f32["N"]:
        m: float = x.mean()
        xc: f32["N N"] = x - m  # Wants to be NxN, won't be caught
        return xc / (x.std() + eps)

    with does_not_raise():
        t1 = standardize(vec_f32)
    # The "N N" annotation on xc was violated silently: the result is still 1-D.
    assert t1.shape == vec_f32.shape

    # Typecheck with jaxtyping types - checks annotated assignments, so it raises
    @partial(typecheck, show_src=True)
    def standardize_tc(x: jaxtyping.f32["N"], eps=1e-5) -> f32["N"]:
        m: jaxtyping.f32[""] = x.mean()
        xc: f32["N N"] = x - m  # Wants to be NxN; typecheck catches the mismatch
        return xc / (x.std() + eps)

    with pytest.raises(TypeError, match=r"xc not of type f32\['N N'\]"):
        standardize_tc(vec_f32)

    # TODO: cover 2-D inputs, e.g.
    #   embeddings = jax.random.uniform(rng, (11, 13))
    #   t1 = standardize(embeddings)
    #   t1 = jax.vmap(standardize)(embeddings)