Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Oct 11, 2023
1 parent 8a74f79 commit 702c51d
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 31 deletions.
14 changes: 14 additions & 0 deletions tessellate_ipu/core/vertex/intrinsics_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@
// ALWAYS_INLINE currently expands to the plain `inline` keyword; the variant
// based on the `always_inline` attribute is kept below, commented out.
// #define ALWAYS_INLINE __attribute__((always_inline))
#define ALWAYS_INLINE inline

/**
* Tag dispatching, between IPU model and IPU hardware implementations.
*
* This should make it easier to maintain the IPU hardware and IPU model
* implementations side by side, without #ifdef/#endif preprocessor spaghetti code.
*/
namespace ipu {

/** Empty tag type used to select the IPU hardware overload of a function. */
struct HardwareTag {};

/** Empty tag type used to select the IPU model overload of a function. */
struct ModelTag {};

}  // namespace ipu

#ifdef __IPU__

/**
Expand Down
6 changes: 3 additions & 3 deletions tessellate_ipu/core/vertex/tile_small_dot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ class Rotation2dVertex : public MultiVertex {

static constexpr size_t MIN_ALIGN = 8;

Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
cs; // (2,) rotation cosinus/sinus values
Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
inrow0; // (N,) first input row vector
Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
inrow1; // (N,) second input row vector

Input<Vector<T, poplar::VectorLayout::ONE_PTR>>
cs; // (2,) rotation cosinus/sinus values

Input<Vector<IndexType, poplar::VectorLayout::ONE_PTR>>
worker_offsets; // (7,) number threads + 1.
Expand All @@ -40,4 +40,4 @@ class Rotation2dVertex : public MultiVertex {

return true;
}
};
};
20 changes: 19 additions & 1 deletion tessellate_ipu/core/vertex/tile_small_dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,22 @@
// #include <poplar/HalfFloat.hpp>
// #include <poplar/Vertex.hpp>

#include "intrinsics_utils.hpp"
#include <cstring>

#include "intrinsics_utils.hpp"

/**
 * Reinterpret the object representation of `x` as a value of type `R`.
 *
 * The bytes are copied with std::memcpy rather than read through a
 * `reinterpret_cast<R *>` pointer: accessing a `T` object through an unrelated
 * `R *` violates the strict aliasing rules and is undefined behaviour.
 *
 * @tparam R Result type. Must be default-constructible, trivially copyable and
 *           no larger than `T`.
 * @tparam T Source type (deduced from the argument); must be trivially copyable.
 * @param x Value whose first `sizeof(R)` bytes are reinterpreted.
 * @return A value of type `R` holding the first `sizeof(R)` bytes of `x`.
 */
template <class R, class T>
R as(T x) {
  // Reading sizeof(R) bytes from a smaller source would run past `x`.
  static_assert(sizeof(R) <= sizeof(T),
                "as<R>(T): R must not be larger than the source type T.");
  R result;
  std::memcpy(&result, &x, sizeof(R));
  return result;
}


/**
 * 2d rotation on two float rows: overload selected by passing an
 * `ipu::HardwareTag` as the last argument.
 *
 * NOTE(review): the body is empty in this revision, so the function reads no
 * input and writes nothing to `outrow0`/`outrow1`; it is only a placeholder.
 *
 * Parameter meanings below are inferred from names only - TODO confirm:
 *  - inrow0/inrow1: input rows, as float2 vectors.
 *  - cs: presumably the rotation cosine/sine pair.
 *  - outrow0/outrow1: output rows, as float2 vectors.
 *  - nblocks: presumably the number of float2 blocks to process.
 */
inline void rotation2_float(const float2 *inrow0, const float2 *inrow1,
float2 cs, float2 *outrow0, float2 *outrow1,
rptsize_t nblocks, ipu::HardwareTag) {}


/**
 * 2d rotation on two float rows: overload selected by passing an
 * `ipu::ModelTag` as the last argument.
 *
 * NOTE(review): the body is empty in this revision, so the function reads no
 * input and writes nothing to `outrow0`/`outrow1`; it is only a placeholder.
 *
 * Parameter meanings below are inferred from names only - TODO confirm:
 *  - inrow0/inrow1: input rows, as float2 vectors.
 *  - cs: presumably the rotation cosine/sine pair.
 *  - outrow0/outrow1: output rows, as float2 vectors.
 *  - nblocks: presumably the number of float2 blocks to process.
 */
inline void rotation2_float(const float2 *inrow0, const float2 *inrow1,
float2 cs, float2 *outrow0, float2 *outrow1,
rptsize_t nblocks, ipu::ModelTag) {
// Intentionally empty for now: no IPU model implementation yet.


}
1 change: 1 addition & 0 deletions tessellate_ipu/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .tile_lax_dot import IpuConvVertexType
from .tile_lax_gather import gather_p
from .tile_lax_scatter import scatter_add_p, scatter_max_p, scatter_min_p, scatter_mul_p, scatter_p
from .tile_lax_small_dot import rotation2d_p
from .tile_lax_unary import ( # tanh_inplace_p,
abs_inplace_p,
asin_inplace_p,
Expand Down
73 changes: 46 additions & 27 deletions tessellate_ipu/lax/tile_lax_small_dot.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,49 @@
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Dict, List, Tuple
import os
from typing import Any, Dict

import numpy as np
from jax.core import Primitive, ShapedArray
from jax.lax import dot_general_p

from tessellate_ipu.core.tile_interpreter import register_ipu_tile_primitive
from tessellate_ipu.core.tile_interpreter_primitives import (
IpuTileMapEquation,
from_numpy_dtype_to_ipu_type,
make_ipu_vertex_constant_info,
make_ipu_vertex_in_info,
make_ipu_vertex_out_info,
tile_map_remove_ipu_attributes,
)
from tessellate_ipu.core.tile_interpreter_vertex_utils import make_num_elements_per_worker

# from tessellate_ipu.lib.pytessellate_ipu_core import ( # noqa: F401
# IpuVertexAttributeI32,
# ipuGetTransformedInRowStride,
# ipuGetTransformedInStride,
# ipuGetTransformedOutStride,
# ipuReverseTransformedInRowStride,
# ipuReverseTransformedInStride,
# ipuReverseTransformedOutStride,
# )
from tessellate_ipu.utils import DType, DTypeLike, NDArray
from jax.core import ShapedArray

from tessellate_ipu.core import declare_ipu_tile_primitive
from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets

# from tessellate_ipu.utils import DType, DTypeLike, NDArray


def get_small_dot_vertex_gp_filename() -> str:
    """Return the path of the `tile_small_dot.cpp` vertex source file.

    The path is built relative to the directory containing this module and is
    returned as-is (not normalised).
    """
    module_dir = os.path.dirname(__file__)
    relative_parts = ("../core", "vertex", "tile_small_dot.cpp")
    return os.path.join(module_dir, *relative_parts)


@declare_ipu_tile_primitive("Rotation2dVertex", gp_filename=get_small_dot_vertex_gp_filename())
def rotation2d_p(cs: ShapedArray, inrow0: ShapedArray, inrow1: ShapedArray):
    """Declare the IPU tile primitive applying a 2d rotation to two rows.

    IPU-specific alternative to the `dot_general_p` primitive: the two rows of
    the (2, N) input are passed as separate arrays, since contiguous storage
    may not be possible in some applications.

    Only float32 inputs are accepted; `cs`, `inrow0` and `inrow1` must share the
    same dtype, and the two row avals must compare equal.

    Args:
        cs: Cos/sin entries of the 2d rotation.
        inrow0: First input row (N,).
        inrow1: Second input row (N,).
    Returns:
        Tuple `(outputs, constants, temps, perf_estimate)` where:
        outputs: `outrow0` and `outrow1` avals, identical to the input row avals.
        constants: `worker_offsets` factory, built from the row size.
        temps: Empty dictionary (no temporary buffers declared).
        perf_estimate: Fixed estimate of 100.
    """
    num_elements = inrow0.size
    # Input validation: matching rows, and float32 everywhere.
    assert inrow0 == inrow1
    assert cs.dtype == inrow0.dtype
    assert cs.dtype == inrow1.dtype
    assert inrow0.dtype == np.float32

    temps: Dict[str, Any] = {}
    constants = {
        "worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets(
            num_elements, vector_size=2, wdtype=np.uint16
        )
    }
    outputs = dict(outrow0=inrow0, outrow1=inrow1)
    return outputs, constants, temps, 100
12 changes: 12 additions & 0 deletions tests/lax/test_tile_lax_small_dot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from functools import partial

import chex
import jax
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized

from tessellate_ipu import TileShardedArray, tile_map, tile_put_sharded
from tessellate_ipu.lax.tile_lax_small_dot import rotation2d_p

0 comments on commit 702c51d

Please sign in to comment.