Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support select operation in TessellateIPU #28

Merged
merged 1 commit into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/operations.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
| `bitwise_and` | :white_check_mark: | :x: | |
| `bitwise_or` | :white_check_mark: | :x: | |
| `bitwise_xor` | :white_check_mark: | :x: | |
| `population_count` | :x: | :x: | |
| `population_count` | :white_check_mark: | :x: | |
| `broadcast` | :x: | :x: | |
| `broadcast_in_dim` | :x: | :x: | |
| `cbrt` | :white_check_mark: | :white_check_mark: | |
Expand Down Expand Up @@ -100,7 +100,7 @@
| `scatter_max` | :white_check_mark: | :x: | Limited set of configurations. See below. |
| `scatter_min` | :white_check_mark: | :x: | Limited set of configurations. See below. |
| `scatter_mul` | :white_check_mark: | :x: | Limited set of configurations. See below. |
| `select` | :x: | :x: | |
| `select` | :white_check_mark: | :x: | |
| `shift_left` | :white_check_mark: | :x: | |
| `shift_right_arithmetic`| :white_check_mark: | :x: | |
| `shift_right_logical` | :white_check_mark: | :x: | |
Expand Down
1 change: 1 addition & 0 deletions tessellate_ipu/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
population_count_p,
round_p,
rsqrt_p,
select_n_p,
sign_p,
sin_p,
sqrt_p,
Expand Down
49 changes: 49 additions & 0 deletions tessellate_ipu/lax/tile_lax_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
make_ipu_vertex_attributes,
make_ipu_vertex_in_info,
make_ipu_vertex_inout_info,
make_ipu_vertex_name_templated,
make_ipu_vertex_out_info,
primitive_clone,
register_ipu_tile_primitive,
Expand Down Expand Up @@ -278,3 +279,51 @@ def register_ipu_binary_inplace_tile_primitive(orig_prim):
pow_inplace_p = register_ipu_binary_inplace_tile_primitive(lax.pow_p)
rem_inplace_p = register_ipu_binary_inplace_tile_primitive(lax.rem_p)
sub_inplace_p = register_ipu_binary_inplace_tile_primitive(lax.sub_p)


def ipu_select_primitive_translation(
p: Primitive,
tiles: Tuple[int, ...],
inavals: List[ShapedArray],
attributes: Dict[str, Any] = None,
) -> IpuTileMapEquation:
"""IPU select_n LAX primitive translation rule to IPU vertex.

Args:
p: JAX primitive.
tiles: Collection of tiles.
inavals: Input shaped arrays.
attributes: (unused) attributes.
Returns:
IPU tile map primitive structure.
"""
assert len(inavals) == 3
cond, x, y = inavals
# A couple of initial checks!
assert cond.shape == x.shape
assert cond.shape == y.shape
assert cond.dtype == np.bool_
assert x.dtype == y.dtype

vname = make_ipu_vertex_name_templated("popops::Select", x.dtype)
# Note: using `vertex_dim2=1` as Select vertex expecting vector of vector.
inputs_info = [
make_ipu_vertex_in_info("in3", cond, vertex_dim2=1),
make_ipu_vertex_in_info("in1", x, vertex_dim2=1),
make_ipu_vertex_in_info("in2", y, vertex_dim2=1),
]
outputs_info = [make_ipu_vertex_out_info("out", x, vertex_dim2=1)]
ipu_prim_info = IpuTileMapEquation(
vname=vname,
pname=p.name,
tiles=tiles,
inputs_info=inputs_info,
outputs_info=outputs_info,
attributes_i32=[],
attributes_f32=[],
)
return ipu_prim_info


# Register JAX LAX select primitive.
register_ipu_tile_primitive(lax.select_n_p, ipu_select_primitive_translation)
22 changes: 22 additions & 0 deletions tests/lax/test_tile_lax_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,28 @@ def compute_fn(A, B, sB):
assert output.dtype == A.dtype
npt.assert_array_almost_equal(output.array, scale_op_p.impl(A, B, sB), decimal=2)

@parameterized.parameters([np.float32])
def test__tile_map__select__ipu_jitting__proper_result(self, dtype):
tiles = (3, 4, 5)
inshape = (len(tiles), 7, 9)
mask = np.random.rand(*inshape) >= 0.5
input0 = np.random.randn(*inshape).astype(dtype)
input1 = np.random.randn(*inshape).astype(dtype)

@partial(jax.jit, backend="ipu")
def compute_fn(mask, in0, in1):
mask = tile_put_sharded(mask, tiles)
input0 = tile_put_sharded(in0, tiles)
input1 = tile_put_sharded(in1, tiles)
output = tile_map(lax.select_n_p, mask, input0, input1)
return output

output = compute_fn(mask, input0, input1)
assert isinstance(output, TileShardedArray)
assert output.tiles == tiles
assert output.dtype == input0.dtype
npt.assert_array_almost_equal(output.array, np.where(mask, input0, input1))


class IpuTileShiftPrimitivesTests(chex.TestCase):
def setUp(self):
Expand Down