From 2482e3339aae4dac03d67d5ff8129ed844f90985 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Thu, 28 Sep 2023 20:51:39 +0000 Subject: [PATCH] wip --- tessellate_ipu/lax/__init__.py | 2 +- tessellate_ipu/lax/tile_lax_array.py | 3 ++- tests/lax/test_tile_lax_array.py | 20 +++++++++++++++++++- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/tessellate_ipu/lax/__init__.py b/tessellate_ipu/lax/__init__.py index 58dbe27..42e42d3 100644 --- a/tessellate_ipu/lax/__init__.py +++ b/tessellate_ipu/lax/__init__.py @@ -25,7 +25,7 @@ ) from . import tile_lax_binary, tile_lax_dot, tile_lax_reduce, tile_lax_unary, tile_random -from .tile_lax_array import bitcast_convert_type_p, reshape_p +from .tile_lax_array import bitcast_convert_type_p, fill_p, reshape_p from .tile_lax_binary import ( add_inplace_p, atan2_inplace_p, diff --git a/tessellate_ipu/lax/tile_lax_array.py b/tessellate_ipu/lax/tile_lax_array.py index 6614c6a..4eaa8f5 100644 --- a/tessellate_ipu/lax/tile_lax_array.py +++ b/tessellate_ipu/lax/tile_lax_array.py @@ -11,6 +11,7 @@ make_ipu_vertex_attributes, make_ipu_vertex_inout_info, make_ipu_vertex_name_templated, + make_ipu_vertex_out_info, register_ipu_tile_primitive, ) from tessellate_ipu.utils import DTypeLike @@ -156,7 +157,7 @@ def ipu_fill_primitive_translation_ipu( pname=p.name, tiles=tiles, inputs_info=[], - outputs_info=[make_ipu_vertex_inout_info("out", outaval)], + outputs_info=[make_ipu_vertex_out_info("out", outaval)], attributes_i32=attrs_i32, attributes_f32=attrs_f32, ) diff --git a/tests/lax/test_tile_lax_array.py b/tests/lax/test_tile_lax_array.py index 0cdb0ac..0030a50 100644 --- a/tests/lax/test_tile_lax_array.py +++ b/tests/lax/test_tile_lax_array.py @@ -7,7 +7,7 @@ import numpy.testing as npt from tessellate_ipu import TileShardedArray, tile_map, tile_put_sharded -from tessellate_ipu.lax import bitcast_convert_type_p, reshape_p +from tessellate_ipu.lax import bitcast_convert_type_p, fill_p, reshape_p class IpuTileArrayPrimitiveTests(chex.TestCase): @@ -54,3 +54,21 @@ def compute_fn(input): assert output_ipu.tiles == tiles assert output_ipu.dtype == np.int32 npt.assert_array_equal(output_ipu, output_cpu) + + def test__tile_map__fill__ipu_jitting__proper_result(self): + tiles = (3, 4, 5) + dtype = np.int32 + # inshape = (len(tiles), 6, 4) + + def compute_fn(): + return tile_map(fill_p, shape=(4, 5), fill_value=1, dtype=dtype, tiles=tiles) + + compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn) + output_ipu = compute_fn_ipu() + assert isinstance(output_ipu, TileShardedArray) + assert output_ipu.tiles == tiles + # assert output_ipu.dtype == np.int32 + + print(output_ipu) + assert False + # npt.assert_array_equal(output_ipu, output_cpu)