-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Basic support of
reshape
in TessellateIPU tile_map. (#32)
Note: not supporting the `dimensions` argument which implies a transpose additionally to the reshape.
- Loading branch information
Showing
5 changed files
with
179 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
from typing import Any, Dict, List, Tuple | ||
|
||
from jax.core import Primitive, ShapedArray | ||
from jax.lax import bitcast_convert_type_p, reshape_p | ||
|
||
from tessellate_ipu.core import IpuTileMapEquation, make_ipu_vertex_inout_info, register_ipu_tile_primitive | ||
|
||
|
||
def ipu_reshape_primitive_translation( | ||
p: Primitive, | ||
tiles: Tuple[int, ...], | ||
inavals: List[ShapedArray], | ||
attributes: Dict[str, Any] = None, | ||
) -> IpuTileMapEquation: | ||
"""IPU `reshape` LAX primitive translation rule to IPU vertex. | ||
Args: | ||
p: JAX primitive. | ||
tiles: Collection of tiles. | ||
inavals: Input shaped arrays. | ||
attributes: (unused) attributes. | ||
Returns: | ||
IPU tile map primitive structure. | ||
""" | ||
assert len(inavals) == 1 | ||
assert attributes is not None | ||
inaval = inavals[0] | ||
new_sizes = attributes["new_sizes"] | ||
dimensions = attributes.get("dimensions", None) | ||
if dimensions is not None: | ||
raise NotImplementedError("TessellateIPU `reshape` does not support a custom `dimensions` argument.") | ||
|
||
outaval = ShapedArray(new_sizes, dtype=inaval.dtype, weak_type=inaval.dtype) | ||
# Empty vertex name trick => identity function with inout argument, just doing reshaping. | ||
vname = "" | ||
inputs_info = [ | ||
make_ipu_vertex_inout_info("x", inaval), | ||
] | ||
outputs_info = [make_ipu_vertex_inout_info("x", outaval)] | ||
ipu_prim_info = IpuTileMapEquation( | ||
vname=vname, | ||
pname=p.name, | ||
tiles=tiles, | ||
inputs_info=inputs_info, | ||
outputs_info=outputs_info, | ||
attributes_i32=[], | ||
attributes_f32=[], | ||
) | ||
return ipu_prim_info | ||
|
||
|
||
# Register JAX LAX reshape primitive. | ||
register_ipu_tile_primitive(reshape_p, ipu_reshape_primitive_translation) | ||
|
||
|
||
def ipu_bitcast_convert_type_primitive_translation( | ||
p: Primitive, | ||
tiles: Tuple[int, ...], | ||
inavals: List[ShapedArray], | ||
attributes: Dict[str, Any] = None, | ||
) -> IpuTileMapEquation: | ||
"""IPU `bitcast_convert_type` LAX primitive translation rule to IPU vertex. | ||
Args: | ||
p: JAX primitive. | ||
tiles: Collection of tiles. | ||
inavals: Input shaped arrays. | ||
attributes: (unused) attributes. | ||
Returns: | ||
IPU tile map primitive structure. | ||
""" | ||
assert len(inavals) == 1 | ||
assert attributes is not None | ||
inaval = inavals[0] | ||
new_dtype = attributes["new_dtype"] | ||
outaval = ShapedArray(inaval.shape, dtype=new_dtype, weak_type=inaval.dtype) | ||
# Empty vertex name trick => identity function with inout argument, just doing reshaping. | ||
vname = "" | ||
inputs_info = [ | ||
make_ipu_vertex_inout_info("x", inaval), | ||
] | ||
outputs_info = [make_ipu_vertex_inout_info("x", outaval)] | ||
ipu_prim_info = IpuTileMapEquation( | ||
vname=vname, | ||
pname=p.name, | ||
tiles=tiles, | ||
inputs_info=inputs_info, | ||
outputs_info=outputs_info, | ||
attributes_i32=[], | ||
attributes_f32=[], | ||
) | ||
return ipu_prim_info | ||
|
||
|
||
# Register JAX LAX bitcast_convert_type_p primitive. | ||
register_ipu_tile_primitive(bitcast_convert_type_p, ipu_bitcast_convert_type_primitive_translation) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Copyright (c) 2022 Graphcore Ltd. All rights reserved. | ||
from functools import partial | ||
|
||
import chex | ||
import jax | ||
import numpy as np | ||
import numpy.testing as npt | ||
|
||
from tessellate_ipu import TileShardedArray, tile_map, tile_put_sharded | ||
from tessellate_ipu.lax import bitcast_convert_type_p, reshape_p | ||
|
||
|
||
class IpuTileArrayPrimitiveTests(chex.TestCase): | ||
def setUp(self): | ||
super().setUp() | ||
np.random.seed(42) | ||
|
||
def test__tile_map__reshape__ipu_jitting__proper_result(self): | ||
tiles = (3, 4, 5) | ||
dtype = np.float32 | ||
inshape = (len(tiles), 6, 4) | ||
indata = np.random.randn(*inshape).astype(dtype) | ||
|
||
def compute_fn(input): | ||
input = tile_put_sharded(input, tiles) | ||
return tile_map(reshape_p, input, new_sizes=(3, 8), dimensions=None) | ||
|
||
compute_fn_cpu = partial(jax.jit, backend="cpu")(compute_fn) | ||
compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn) | ||
|
||
output_cpu = compute_fn_cpu(indata) | ||
output_ipu = compute_fn_ipu(indata) | ||
assert isinstance(output_ipu, TileShardedArray) | ||
assert output_ipu.tiles == tiles | ||
assert output_ipu.dtype == indata.dtype | ||
npt.assert_array_equal(output_ipu, output_cpu) | ||
|
||
def test__tile_map__bitcast_convert_type__ipu_jitting__proper_result(self): | ||
tiles = (3, 4, 5) | ||
dtype = np.float32 | ||
inshape = (len(tiles), 6, 4) | ||
indata = np.random.randn(*inshape).astype(dtype) | ||
|
||
def compute_fn(input): | ||
input = tile_put_sharded(input, tiles) | ||
return tile_map(bitcast_convert_type_p, input, new_dtype=np.int32) | ||
|
||
compute_fn_cpu = partial(jax.jit, backend="cpu")(compute_fn) | ||
compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn) | ||
|
||
output_cpu = compute_fn_cpu(indata) | ||
output_ipu = compute_fn_ipu(indata) | ||
assert isinstance(output_ipu, TileShardedArray) | ||
assert output_ipu.tiles == tiles | ||
assert output_ipu.dtype == np.int32 | ||
npt.assert_array_equal(output_ipu, output_cpu) |