-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
95 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,49 @@ | ||
# Copyright (c) 2022 Graphcore Ltd. All rights reserved. | ||
from dataclasses import dataclass | ||
from enum import IntEnum | ||
from typing import Any, Dict, List, Tuple | ||
import os | ||
from typing import Any, Dict | ||
|
||
import numpy as np | ||
from jax.core import Primitive, ShapedArray | ||
from jax.lax import dot_general_p | ||
|
||
from tessellate_ipu.core.tile_interpreter import register_ipu_tile_primitive | ||
from tessellate_ipu.core.tile_interpreter_primitives import ( | ||
IpuTileMapEquation, | ||
from_numpy_dtype_to_ipu_type, | ||
make_ipu_vertex_constant_info, | ||
make_ipu_vertex_in_info, | ||
make_ipu_vertex_out_info, | ||
tile_map_remove_ipu_attributes, | ||
) | ||
from tessellate_ipu.core.tile_interpreter_vertex_utils import make_num_elements_per_worker | ||
|
||
# from tessellate_ipu.lib.pytessellate_ipu_core import ( # noqa: F401 | ||
# IpuVertexAttributeI32, | ||
# ipuGetTransformedInRowStride, | ||
# ipuGetTransformedInStride, | ||
# ipuGetTransformedOutStride, | ||
# ipuReverseTransformedInRowStride, | ||
# ipuReverseTransformedInStride, | ||
# ipuReverseTransformedOutStride, | ||
# ) | ||
from tessellate_ipu.utils import DType, DTypeLike, NDArray | ||
from jax.core import ShapedArray | ||
|
||
from tessellate_ipu.core import declare_ipu_tile_primitive | ||
from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets | ||
|
||
# from tessellate_ipu.utils import DType, DTypeLike, NDArray | ||
|
||
|
||
def get_small_dot_vertex_gp_filename() -> str: | ||
return os.path.join(os.path.dirname(__file__), "../core", "vertex", "tile_small_dot.cpp") | ||
|
||
|
||
@declare_ipu_tile_primitive("Rotation2dVertex", gp_filename=get_small_dot_vertex_gp_filename()) | ||
def rotation2d_p(cs: ShapedArray, inrow0: ShapedArray, inrow1: ShapedArray): | ||
"""2d rotation apply primitive. | ||
Specific optimization on IPU backend compared to `dot_general_p` primitive. | ||
In particular, allows passing the 2 rows of the (2, N) input as separate arrays (in some | ||
applications, contiguous storage may not be possible). | ||
Args: | ||
cs: Cos/sin 2d rotation entries. | ||
inrow0: First row (N,) | ||
inrow1: Second row (N,) | ||
Returns: | ||
outrow0: First output row (N,) | ||
outrow1: Second output row (N,) | ||
""" | ||
N = inrow0.size | ||
assert inrow0 == inrow1 | ||
assert cs.dtype == inrow0.dtype | ||
assert cs.dtype == inrow1.dtype | ||
assert inrow0.dtype == np.float32 | ||
|
||
outputs = { | ||
"outrow0": inrow0, | ||
"outrow1": inrow1, | ||
} | ||
constants = { | ||
"worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets(N, vector_size=2, wdtype=np.uint16) | ||
} | ||
temps: Dict[str, Any] = {} | ||
perf_estimate = 100 | ||
return outputs, constants, temps, perf_estimate |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
from functools import partial | ||
|
||
import chex | ||
import jax | ||
import numpy as np | ||
import numpy.testing as npt | ||
from absl.testing import parameterized | ||
|
||
from tessellate_ipu import TileShardedArray, tile_map, tile_put_sharded | ||
from tessellate_ipu.lax.tile_lax_small_dot import rotation2d_p | ||
|