Skip to content

Commit

Permalink
Support reciprocal operation in TessellateIPU. (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap authored Sep 27, 2023
1 parent ad04978 commit e149c98
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 23 deletions.
36 changes: 18 additions & 18 deletions docs/operations.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,26 +75,26 @@
| `le` | :white_check_mark: | :x: | |
| `lt` | :white_check_mark: | :x: | |
| `lgamma` | :x: | :x: | |
| `log` | :white_check_mark: | :x: | |
| `log1p` | :white_check_mark: | :x: | |
| `log` | :white_check_mark: | :white_check_mark: | |
| `log1p` | :white_check_mark: | :white_check_mark: | |
| `logistic` | :x: | :x: | |
| `max` | :white_check_mark: | :x: | |
| `min` | :white_check_mark: | :x: | |
| `mul` | :white_check_mark: | :x: | |
| `max` | :white_check_mark: | :white_check_mark: | |
| `min` | :white_check_mark: | :white_check_mark: | |
| `mul` | :white_check_mark: | :white_check_mark: | |
| `ne` | :white_check_mark: | :x: | |
| `neg` | :white_check_mark: | :x: | |
| `neg` | :white_check_mark: | :white_check_mark: | |
| `nextafter` | :x: | :x: | |
| `pad` | :x: | :x: | |
| `polygamma` | :x: | :x: | |
| `pow` | :white_check_mark: | :x: | |
| `pow` | :white_check_mark: | :white_check_mark: | |
| `real` | :x: | :x: | |
| `reciprocal` | :x: | :x: | |
| `reciprocal` | :white_check_mark: | :x: | |
| `reduce` | :white_check_mark: | :x: | |
| `reshape` | :white_check_mark: | :x: | |
| `rem` | :white_check_mark: | :x: | |
| `rev` | :white_check_mark: | :x: | |
| `round` | :white_check_mark: | :x: | |
| `rsqrt` | :white_check_mark: | :x: | |
| `reshape` | :x: | :x: | |
| `rem` | :white_check_mark: | :white_check_mark: | |
| `rev` | :x: | :x: | |
| `round` | :white_check_mark: | :white_check_mark: | |
| `rsqrt` | :white_check_mark: | :white_check_mark: | |
| `scatter` | :white_check_mark: | :x: | Limited set of configurations. See below. |
| `scatter_add` | :white_check_mark: | :x: | Limited set of configurations. See below. |
| `scatter_max` | :white_check_mark: | :x: | Limited set of configurations. See below. |
Expand All @@ -107,15 +107,15 @@
| `slice` | :white_check_mark: | :x: | |
| `slice_in_dim` | :white_check_mark: | :x: | |
| `sign` | :white_check_mark: | :x: | |
| `sin` | :white_check_mark: | :x: | |
| `sinh` | :white_check_mark: | :x: | |
| `sin` | :white_check_mark: | :white_check_mark: | |
| `sinh` | :x: | :x: | |
| `sort` | :x: | :x: | |
| `sort_key_val` | :x: | :x: | |
| `sqrt` | :white_check_mark: | :x: | |
| `sqrt` | :white_check_mark: | :white_check_mark: | |
| `square` | :white_check_mark: | :x: | |
| `squeeze` | :white_check_mark: | :x: | |
| `sub` | :white_check_mark: | :x: | |
| `tan` | :white_check_mark: | :x: | |
| `sub` | :white_check_mark: | :white_check_mark: | |
| `tan` | :white_check_mark: | :white_check_mark: | |
| `tie_in` | :x: | :x: | Deprecated in JAX |
| `top_k` | :x: | :x: | |
| `transpose` | :white_check_mark: | :x: | Copies the input tensor |
Expand Down
11 changes: 7 additions & 4 deletions tessellate_ipu/lax/tile_lax_unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ def ipu_integer_pow_translation(
) -> IpuTileMapEquation:
"""IPU `integer_pow` primitive translation rule to IPU vertex.
Only supporting -1 and 2 exponents at the moment.
Args:
p: JAX primitive.
tiles: Collection of tiles.
Expand All @@ -203,12 +205,13 @@ def ipu_integer_pow_translation(
assert attributes is not None
inaval = inavals[0]
pow = attributes["y"]
if pow != 2:
supported_powers = {-1: "INVERSE", 2: "SQUARE"}
if pow not in supported_powers:
# TODO: general vertex?
raise ValueError("Only supporting integer power of 2 on IPU tile primitives.")
raise ValueError(f"Only supporting integer powers '{tuple(supported_powers.keys())}' in TessellateIPU library.")

# IPU cast arguments.
vname = make_unary1d_vertex_fullname("SQUARE", inaval.dtype, inplace=False)
# Used proper vertex depending on the power!
vname = make_unary1d_vertex_fullname(supported_powers[pow], inaval.dtype, inplace=False)
ipu_prim_info = IpuTileMapEquation(
vname=vname,
pname=p.name,
Expand Down
3 changes: 2 additions & 1 deletion tests/lax/test_tile_lax_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ def compute_fn(input):

@parameterized.parameters(
[
(np.float32, 2),
(np.float32, -1), # reciprocal/inverse vertex
(np.float32, 2), # square vertex
(np.float16, 2),
(np.int32, 2),
]
Expand Down

0 comments on commit e149c98

Please sign in to comment.