Skip to content

Commit

Permalink
[LSC] Ignore incorrect type annotations related to jax.numpy APIs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 568545385
  • Loading branch information
Jake VanderPlas authored and The diffren Authors committed Sep 26, 2023
1 parent 22927a8 commit af686fb
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion diffren/jax/internal/framebuffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __post_init__(self):

try:
for i in range(1, len(shapes)):
if not jnp.array_equal(shapes[0][:-1], shapes[i][:-1]):
if not jnp.array_equal(shapes[0][:-1], shapes[i][:-1]): # pytype: disable=wrong-arg-types # jnp-type
raise ValueError(
f"Expected all input shapes to match (up to channels), "
f"but found {shapes}")
Expand Down

0 comments on commit af686fb

Please sign in to comment.