diff --git a/tessellate_ipu/lax/tile_lax_scatter.py b/tessellate_ipu/lax/tile_lax_scatter.py new file mode 100644 index 0000000..6ecb165 --- /dev/null +++ b/tessellate_ipu/lax/tile_lax_scatter.py @@ -0,0 +1,29 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from typing import Any, Dict, List, Tuple + +import numpy as np +from jax.core import Primitive, ShapedArray +from jax.lax import ( + GatherScatterMode, + ScatterDimensionNumbers, + scatter_add_p, + scatter_max_p, + scatter_min_p, + scatter_mul_p, +) + +from tessellate_ipu.core import ( + IpuTileMapEquation, + make_ipu_vertex_attributes, + make_ipu_vertex_in_info, + make_ipu_vertex_name_templated, + make_ipu_vertex_out_info, + register_ipu_tile_primitive, +) +from tessellate_ipu.utils import DTypeLike + + +def make_gather_vertex_fullname(dtype: DTypeLike) -> str: + """Generate popops Gather/MultiSlice vertex name.""" + basename = "popops::MultiSlice" + return make_ipu_vertex_name_templated(basename, dtype)