Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Sep 28, 2023
1 parent 91a5547 commit 2482e33
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tessellate_ipu/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)

from . import tile_lax_binary, tile_lax_dot, tile_lax_reduce, tile_lax_unary, tile_random
from .tile_lax_array import bitcast_convert_type_p, reshape_p
from .tile_lax_array import bitcast_convert_type_p, fill_p, reshape_p
from .tile_lax_binary import (
add_inplace_p,
atan2_inplace_p,
Expand Down
3 changes: 2 additions & 1 deletion tessellate_ipu/lax/tile_lax_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
make_ipu_vertex_attributes,
make_ipu_vertex_inout_info,
make_ipu_vertex_name_templated,
make_ipu_vertex_out_info,
register_ipu_tile_primitive,
)
from tessellate_ipu.utils import DTypeLike
Expand Down Expand Up @@ -156,7 +157,7 @@ def ipu_fill_primitive_translation_ipu(
pname=p.name,
tiles=tiles,
inputs_info=[],
outputs_info=[make_ipu_vertex_inout_info("out", outaval)],
outputs_info=[make_ipu_vertex_out_info("out", outaval)],
attributes_i32=attrs_i32,
attributes_f32=attrs_f32,
)
Expand Down
20 changes: 19 additions & 1 deletion tests/lax/test_tile_lax_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy.testing as npt

from tessellate_ipu import TileShardedArray, tile_map, tile_put_sharded
from tessellate_ipu.lax import bitcast_convert_type_p, reshape_p
from tessellate_ipu.lax import bitcast_convert_type_p, fill_p, reshape_p


class IpuTileArrayPrimitiveTests(chex.TestCase):
Expand Down Expand Up @@ -54,3 +54,21 @@ def compute_fn(input):
assert output_ipu.tiles == tiles
assert output_ipu.dtype == np.int32
npt.assert_array_equal(output_ipu, output_cpu)

def test__tile_map__fill__ipu_jitting__proper_result(self):
tiles = (3, 4, 5)
dtype = np.int32
# inshape = (len(tiles), 6, 4)

def compute_fn():
return tile_map(fill_p, shape=(4, 5), fill_value=1, dtype=dtype, tiles=tiles)

compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn)
output_ipu = compute_fn_ipu()
assert isinstance(output_ipu, TileShardedArray)
assert output_ipu.tiles == tiles
# assert output_ipu.dtype == np.int32

print(output_ipu)
assert False
# npt.assert_array_equal(output_ipu, output_cpu)

0 comments on commit 2482e33

Please sign in to comment.