Basic support of reshape in TessellateIPU tile_map. (#32)
Note: the `dimensions` argument is not supported, as it implies a transpose in addition to the reshape (a short CPU-side sketch of that equivalence follows the file summary below).
balancap authored Sep 27, 2023
1 parent e149c98 commit 1ca7c92
Showing 5 changed files with 179 additions and 5 deletions.
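
For context on the unsupported `dimensions` argument: in `jax.lax.reshape`, `dimensions` specifies the order in which the operand axes are read before reshaping, i.e. a transpose fused into the reshape. A minimal CPU-side sketch of that equivalence (illustrative only, not part of this commit):

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.asarray(np.arange(24, dtype=np.float32).reshape(6, 4))

# `dimensions` reads the operand axes in the given order before reshaping,
# so it is equivalent to an explicit transpose followed by a plain reshape.
a = lax.reshape(x, (4, 6), dimensions=(1, 0))
b = lax.reshape(lax.transpose(x, (1, 0)), (4, 6))
np.testing.assert_array_equal(a, b)
```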
6 changes: 3 additions & 3 deletions docs/operations.md
@@ -20,7 +20,7 @@
| `bessel_i0e` | :x: | :x: | |
| `bessel_i1e` | :x: | :x: | |
| `betainc` | :x: | :x: | |
| `bitcase_convert_type` | :white_check_mark: | :question: | |
| `bitcast_convert_type` | :white_check_mark: | :white_check_mark: | Only same size dtype supported. |
| `bitwise_not` | :white_check_mark: | :x: | |
| `bitwise_and` | :white_check_mark: | :x: | |
| `bitwise_or` | :white_check_mark: | :x: | |
@@ -90,7 +90,7 @@
| `real` | :x: | :x: | |
| `reciprocal` | :white_check_mark: | :x: | |
| `reduce` | :white_check_mark: | :x: | |
| `reshape` | :x: | :x: | |
| `reshape` | :white_check_mark: | :white_check_mark: | `dimensions` argument not supported. |
| `rem` | :white_check_mark: | :white_check_mark: | |
| `rev` | :x: | :x: | |
| `round` | :white_check_mark: | :white_check_mark: | |
@@ -113,7 +113,7 @@
| `sort_key_val` | :x: | :x: | |
| `sqrt` | :white_check_mark: | :white_check_mark: | |
| `square` | :white_check_mark: | :x: | |
| `squeeze` | :white_check_mark: | :x: | |
| `squeeze` | :white_check_mark: | :white_check_mark: | |
| `sub` | :white_check_mark: | :white_check_mark: | |
| `tan` | :white_check_mark: | :white_check_mark: | |
| `tie_in` | :x: | :x: | Deprecated in JAX |
1 change: 1 addition & 0 deletions tessellate_ipu/lax/__init__.py
@@ -25,6 +25,7 @@
)

from . import tile_lax_binary, tile_lax_dot, tile_lax_reduce, tile_lax_unary, tile_random
from .tile_lax_array import bitcast_convert_type_p, reshape_p
from .tile_lax_binary import (
add_inplace_p,
atan2_inplace_p,
97 changes: 97 additions & 0 deletions tessellate_ipu/lax/tile_lax_array.py
@@ -0,0 +1,97 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Any, Dict, List, Optional, Tuple

from jax.core import Primitive, ShapedArray
from jax.lax import bitcast_convert_type_p, reshape_p

from tessellate_ipu.core import IpuTileMapEquation, make_ipu_vertex_inout_info, register_ipu_tile_primitive


def ipu_reshape_primitive_translation(
p: Primitive,
tiles: Tuple[int, ...],
inavals: List[ShapedArray],
    attributes: Optional[Dict[str, Any]] = None,
) -> IpuTileMapEquation:
"""IPU `reshape` LAX primitive translation rule to IPU vertex.
Args:
p: JAX primitive.
tiles: Collection of tiles.
inavals: Input shaped arrays.
        attributes: Reshape attributes (`new_sizes` and optional `dimensions`).
Returns:
IPU tile map primitive structure.
"""
assert len(inavals) == 1
assert attributes is not None
inaval = inavals[0]
new_sizes = attributes["new_sizes"]
dimensions = attributes.get("dimensions", None)
if dimensions is not None:
raise NotImplementedError("TessellateIPU `reshape` does not support a custom `dimensions` argument.")

    outaval = ShapedArray(new_sizes, dtype=inaval.dtype, weak_type=inaval.weak_type)
# Empty vertex name trick => identity function with inout argument, just doing reshaping.
vname = ""
inputs_info = [
make_ipu_vertex_inout_info("x", inaval),
]
outputs_info = [make_ipu_vertex_inout_info("x", outaval)]
ipu_prim_info = IpuTileMapEquation(
vname=vname,
pname=p.name,
tiles=tiles,
inputs_info=inputs_info,
outputs_info=outputs_info,
attributes_i32=[],
attributes_f32=[],
)
return ipu_prim_info


# Register JAX LAX reshape primitive.
register_ipu_tile_primitive(reshape_p, ipu_reshape_primitive_translation)


def ipu_bitcast_convert_type_primitive_translation(
p: Primitive,
tiles: Tuple[int, ...],
inavals: List[ShapedArray],
    attributes: Optional[Dict[str, Any]] = None,
) -> IpuTileMapEquation:
"""IPU `bitcast_convert_type` LAX primitive translation rule to IPU vertex.
Args:
p: JAX primitive.
tiles: Collection of tiles.
inavals: Input shaped arrays.
        attributes: Bitcast attributes (`new_dtype`).
Returns:
IPU tile map primitive structure.
"""
assert len(inavals) == 1
assert attributes is not None
inaval = inavals[0]
new_dtype = attributes["new_dtype"]
    outaval = ShapedArray(inaval.shape, dtype=new_dtype, weak_type=inaval.weak_type)
    # Empty vertex name trick => identity function with inout argument, just reinterpreting the dtype.
vname = ""
inputs_info = [
make_ipu_vertex_inout_info("x", inaval),
]
outputs_info = [make_ipu_vertex_inout_info("x", outaval)]
ipu_prim_info = IpuTileMapEquation(
vname=vname,
pname=p.name,
tiles=tiles,
inputs_info=inputs_info,
outputs_info=outputs_info,
attributes_i32=[],
attributes_f32=[],
)
return ipu_prim_info


# Register JAX LAX bitcast_convert_type_p primitive.
register_ipu_tile_primitive(bitcast_convert_type_p, ipu_bitcast_convert_type_primitive_translation)
24 changes: 22 additions & 2 deletions tessellate_ipu/lib/tile_map_ops.cpp
@@ -125,10 +125,29 @@ std::vector<poplar::Tensor> TileMapEquation::add(
const poplar::DebugContext& debug_prefix) const {
// All input tensors: i.e. add constant tensors.
const auto inputs_all = this->allocateInputTensors(graph, inputs);
// No vertex => assume identity function, i.e. forward inputs.

  // No vertex => assume identity function.
  // Forward inputs, with only a potential change of shape and dtype.
if (this->vname.empty()) {
return inputs_all;
// Check inputs/outputs consistent.
if (this->numInputs() != this->numOutputs()) {
throw std::logic_error(
"Inconsistent number of inputs/outputs for an identity function.");
}
// Generate output tensors (potential reshaping + change of dtype).
std::vector<poplar::Tensor> outputs_all;
outputs_all.reserve(inputs_all.size());
for (size_t idx = 0; idx < inputs_all.size(); ++idx) {
const auto& in = inputs_all[idx];
const auto& outinfo = outputs_info[idx];
const auto outshape = shapePrependAxis(tiles.size(), outinfo.aval.shape);
const auto outdtype = toPoplar(outinfo.aval.dtype);
auto out = in.reshape(outshape).reinterpret(outdtype);
outputs_all.push_back(out);
}
return outputs_all;
}
// Usual path => map a vertex.
const auto outputs = this->allocateOutputTensors(graph, inputs);
this->add(graph, prog, inputs_all, outputs, debug_prefix);
return outputs;
@@ -147,6 +166,7 @@ std::size_t TileMapEquation::numInOuts() const {
throw std::logic_error(
"Inconsistent number of in/outs in the IPU tile map equation.");
}
// TODO: add checking on tensor size (not necessarily shape).
return num_inouts0;
}

56 changes: 56 additions & 0 deletions tests/lax/test_tile_lax_array.py
@@ -0,0 +1,56 @@
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
from functools import partial

import chex
import jax
import numpy as np
import numpy.testing as npt

from tessellate_ipu import TileShardedArray, tile_map, tile_put_sharded
from tessellate_ipu.lax import bitcast_convert_type_p, reshape_p


class IpuTileArrayPrimitiveTests(chex.TestCase):
def setUp(self):
super().setUp()
np.random.seed(42)

def test__tile_map__reshape__ipu_jitting__proper_result(self):
tiles = (3, 4, 5)
dtype = np.float32
inshape = (len(tiles), 6, 4)
indata = np.random.randn(*inshape).astype(dtype)

def compute_fn(input):
input = tile_put_sharded(input, tiles)
return tile_map(reshape_p, input, new_sizes=(3, 8), dimensions=None)

compute_fn_cpu = partial(jax.jit, backend="cpu")(compute_fn)
compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn)

output_cpu = compute_fn_cpu(indata)
output_ipu = compute_fn_ipu(indata)
assert isinstance(output_ipu, TileShardedArray)
assert output_ipu.tiles == tiles
assert output_ipu.dtype == indata.dtype
npt.assert_array_equal(output_ipu, output_cpu)

def test__tile_map__bitcast_convert_type__ipu_jitting__proper_result(self):
tiles = (3, 4, 5)
dtype = np.float32
inshape = (len(tiles), 6, 4)
indata = np.random.randn(*inshape).astype(dtype)

def compute_fn(input):
input = tile_put_sharded(input, tiles)
return tile_map(bitcast_convert_type_p, input, new_dtype=np.int32)

compute_fn_cpu = partial(jax.jit, backend="cpu")(compute_fn)
compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn)

output_cpu = compute_fn_cpu(indata)
output_ipu = compute_fn_ipu(indata)
assert isinstance(output_ipu, TileShardedArray)
assert output_ipu.tiles == tiles
assert output_ipu.dtype == np.int32
npt.assert_array_equal(output_ipu, output_cpu)
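
A complementary, hypothetical test (not part of this commit) sketching the documented limitation: per the translation rule above, a non-trivial `dimensions` permutation passed through `tile_map` is expected to raise `NotImplementedError` on the IPU backend; the exact point at which the error surfaces may depend on jitting.

```python
from functools import partial

import jax
import numpy as np
import pytest

from tessellate_ipu import tile_map, tile_put_sharded
from tessellate_ipu.lax import reshape_p

tiles = (3, 4, 5)
indata = np.random.randn(len(tiles), 6, 4).astype(np.float32)

def compute_fn(input):
    input = tile_put_sharded(input, tiles)
    # Non-default `dimensions` implies a transpose, which the IPU translation rejects.
    return tile_map(reshape_p, input, new_sizes=(4, 6), dimensions=(1, 0))

with pytest.raises(NotImplementedError):
    partial(jax.jit, backend="ipu")(compute_fn)(indata)
```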
