Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic support of reshape in TessellateIPU tile_map. #32

Merged
merged 1 commit into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/operations.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
| `bessel_i0e` | :x: | :x: | |
| `bessel_i1e` | :x: | :x: | |
| `betainc` | :x: | :x: | |
| `bitcase_convert_type` | :white_check_mark: | :question: | |
| `bitcast_convert_type` | :white_check_mark: | :white_check_mark: | Only same size dtype supported. |
| `bitwise_not` | :white_check_mark: | :x: | |
| `bitwise_and` | :white_check_mark: | :x: | |
| `bitwise_or` | :white_check_mark: | :x: | |
Expand Down Expand Up @@ -90,7 +90,7 @@
| `real` | :x: | :x: | |
| `reciprocal` | :white_check_mark: | :x: | |
| `reduce` | :white_check_mark: | :x: | |
| `reshape` | :x: | :x: | |
| `reshape` | :white_check_mark: | :white_check_mark: | `dimensions` argument not supported. |
| `rem` | :white_check_mark: | :white_check_mark: | |
| `rev` | :x: | :x: | |
| `round` | :white_check_mark: | :white_check_mark: | |
Expand All @@ -113,7 +113,7 @@
| `sort_key_val` | :x: | :x: | |
| `sqrt` | :white_check_mark: | :white_check_mark: | |
| `square` | :white_check_mark: | :x: | |
| `squeeze` | :white_check_mark: | :x: | |
| `squeeze` | :white_check_mark: | :white_check_mark: | |
| `sub` | :white_check_mark: | :white_check_mark: | |
| `tan` | :white_check_mark: | :white_check_mark: | |
| `tie_in` | :x: | :x: | Deprecated in JAX |
Expand Down
1 change: 1 addition & 0 deletions tessellate_ipu/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)

from . import tile_lax_binary, tile_lax_dot, tile_lax_reduce, tile_lax_unary, tile_random
from .tile_lax_array import bitcast_convert_type_p, reshape_p
from .tile_lax_binary import (
add_inplace_p,
atan2_inplace_p,
Expand Down
97 changes: 97 additions & 0 deletions tessellate_ipu/lax/tile_lax_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Any, Dict, List, Tuple

from jax.core import Primitive, ShapedArray
from jax.lax import bitcast_convert_type_p, reshape_p

from tessellate_ipu.core import IpuTileMapEquation, make_ipu_vertex_inout_info, register_ipu_tile_primitive


def ipu_reshape_primitive_translation(
p: Primitive,
tiles: Tuple[int, ...],
inavals: List[ShapedArray],
attributes: Dict[str, Any] = None,
) -> IpuTileMapEquation:
"""IPU `reshape` LAX primitive translation rule to IPU vertex.

Args:
p: JAX primitive.
tiles: Collection of tiles.
inavals: Input shaped arrays.
attributes: (unused) attributes.
Returns:
IPU tile map primitive structure.
"""
assert len(inavals) == 1
assert attributes is not None
inaval = inavals[0]
new_sizes = attributes["new_sizes"]
dimensions = attributes.get("dimensions", None)
if dimensions is not None:
raise NotImplementedError("TessellateIPU `reshape` does not support a custom `dimensions` argument.")

outaval = ShapedArray(new_sizes, dtype=inaval.dtype, weak_type=inaval.dtype)
# Empty vertex name trick => identity function with inout argument, just doing reshaping.
vname = ""
inputs_info = [
make_ipu_vertex_inout_info("x", inaval),
]
outputs_info = [make_ipu_vertex_inout_info("x", outaval)]
ipu_prim_info = IpuTileMapEquation(
vname=vname,
pname=p.name,
tiles=tiles,
inputs_info=inputs_info,
outputs_info=outputs_info,
attributes_i32=[],
attributes_f32=[],
)
return ipu_prim_info


# Register JAX LAX reshape primitive.
register_ipu_tile_primitive(reshape_p, ipu_reshape_primitive_translation)


def ipu_bitcast_convert_type_primitive_translation(
p: Primitive,
tiles: Tuple[int, ...],
inavals: List[ShapedArray],
attributes: Dict[str, Any] = None,
) -> IpuTileMapEquation:
"""IPU `bitcast_convert_type` LAX primitive translation rule to IPU vertex.

Args:
p: JAX primitive.
tiles: Collection of tiles.
inavals: Input shaped arrays.
attributes: (unused) attributes.
Returns:
IPU tile map primitive structure.
"""
assert len(inavals) == 1
assert attributes is not None
inaval = inavals[0]
new_dtype = attributes["new_dtype"]
outaval = ShapedArray(inaval.shape, dtype=new_dtype, weak_type=inaval.dtype)
# Empty vertex name trick => identity function with inout argument, just doing reshaping.
vname = ""
inputs_info = [
make_ipu_vertex_inout_info("x", inaval),
]
outputs_info = [make_ipu_vertex_inout_info("x", outaval)]
ipu_prim_info = IpuTileMapEquation(
vname=vname,
pname=p.name,
tiles=tiles,
inputs_info=inputs_info,
outputs_info=outputs_info,
attributes_i32=[],
attributes_f32=[],
)
return ipu_prim_info


# Register JAX LAX bitcast_convert_type_p primitive.
register_ipu_tile_primitive(bitcast_convert_type_p, ipu_bitcast_convert_type_primitive_translation)
24 changes: 22 additions & 2 deletions tessellate_ipu/lib/tile_map_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,29 @@ std::vector<poplar::Tensor> TileMapEquation::add(
const poplar::DebugContext& debug_prefix) const {
// All input tensors: i.e. add constant tensors.
const auto inputs_all = this->allocateInputTensors(graph, inputs);
// No vertex => assume identity function, i.e. forward inputs.

// No vertex => assume identity function.
// Forwarding inputs, with just potential change of shape and dtype.
if (this->vname.empty()) {
return inputs_all;
// Check inputs/outputs consistent.
if (this->numInputs() != this->numOutputs()) {
throw std::logic_error(
"Inconsistent number of inputs/outputs for an identity function.");
}
// Generate output tensors (potential reshaping + change of dtype).
std::vector<poplar::Tensor> outputs_all;
outputs_all.reserve(inputs_all.size());
for (size_t idx = 0; idx < inputs_all.size(); ++idx) {
const auto& in = inputs_all[idx];
const auto& outinfo = outputs_info[idx];
const auto outshape = shapePrependAxis(tiles.size(), outinfo.aval.shape);
const auto outdtype = toPoplar(outinfo.aval.dtype);
auto out = in.reshape(outshape).reinterpret(outdtype);
outputs_all.push_back(out);
}
return outputs_all;
}
// Usual path => map a vertex.
const auto outputs = this->allocateOutputTensors(graph, inputs);
this->add(graph, prog, inputs_all, outputs, debug_prefix);
return outputs;
Expand All @@ -147,6 +166,7 @@ std::size_t TileMapEquation::numInOuts() const {
throw std::logic_error(
"Inconsistent number of in/outs in the IPU tile map equation.");
}
// TODO: add checking on tensor size (not necessarily shape).
return num_inouts0;
}

Expand Down
56 changes: 56 additions & 0 deletions tests/lax/test_tile_lax_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
from functools import partial

import chex
import jax
import numpy as np
import numpy.testing as npt

from tessellate_ipu import TileShardedArray, tile_map, tile_put_sharded
from tessellate_ipu.lax import bitcast_convert_type_p, reshape_p


class IpuTileArrayPrimitiveTests(chex.TestCase):
def setUp(self):
super().setUp()
np.random.seed(42)

def test__tile_map__reshape__ipu_jitting__proper_result(self):
tiles = (3, 4, 5)
dtype = np.float32
inshape = (len(tiles), 6, 4)
indata = np.random.randn(*inshape).astype(dtype)

def compute_fn(input):
input = tile_put_sharded(input, tiles)
return tile_map(reshape_p, input, new_sizes=(3, 8), dimensions=None)

compute_fn_cpu = partial(jax.jit, backend="cpu")(compute_fn)
compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn)

output_cpu = compute_fn_cpu(indata)
output_ipu = compute_fn_ipu(indata)
assert isinstance(output_ipu, TileShardedArray)
assert output_ipu.tiles == tiles
assert output_ipu.dtype == indata.dtype
npt.assert_array_equal(output_ipu, output_cpu)

def test__tile_map__bitcast_convert_type__ipu_jitting__proper_result(self):
tiles = (3, 4, 5)
dtype = np.float32
inshape = (len(tiles), 6, 4)
indata = np.random.randn(*inshape).astype(dtype)

def compute_fn(input):
input = tile_put_sharded(input, tiles)
return tile_map(bitcast_convert_type_p, input, new_dtype=np.int32)

compute_fn_cpu = partial(jax.jit, backend="cpu")(compute_fn)
compute_fn_ipu = partial(jax.jit, backend="ipu")(compute_fn)

output_cpu = compute_fn_cpu(indata)
output_ipu = compute_fn_ipu(indata)
assert isinstance(output_ipu, TileShardedArray)
assert output_ipu.tiles == tiles
assert output_ipu.dtype == np.int32
npt.assert_array_equal(output_ipu, output_cpu)