Skip to content

Commit

Permalink
Fix Gather op
Browse files Browse the repository at this point in the history
  • Loading branch information
EmergentOrder committed Apr 21, 2021
1 parent 689ba2a commit 3ff58b3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
17 changes: 9 additions & 8 deletions core/src/main/scala/ONNX.scala
Original file line number Diff line number Diff line change
Expand Up @@ -429,20 +429,21 @@ package object onnx {
}
}
//Missing in NDScala - P3
//need a match type
trait GatherV13 extends Operator {
def GatherV13[
@sp T <: UByte | UShort | UInt | ULong | Byte | Short | Int | Long | BFloat16 | Float16 | Float | Double | String | Boolean | Complex[
Float
] | Complex[Double]: Numeric,
@sp Tind <: Int | Long: Numeric
, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape, Tt2 <: TensorTypeDenotation, Td2 <: TensorShapeDenotation, S2 <: Shape](
] | Complex[Double],
@sp Tind <: Int : Numeric, //Spec also supports long
Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape, Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape, Tt2 <: TensorTypeDenotation, Td2 <: TensorShapeDenotation, AxisIndex <: Index ::: INil, AxisIndices <: Indices](
name: String,
axis: Int = 0,
axis: AxisIndex = 0 ::: INil,
data: Tensor[T, Tuple3[Tt,Td,S]],
indices: Tensor[Tind, Tuple3[Tt1,Td1,S1]]
)(using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td2], s: ShapeOf[S2]): Tensor[T, Tuple3[Tt2,Td2,S2]] = {
val map: Map[String, Any] = Map("axis" -> axis)
val allInputs = Tuple2(data, indices)
indices: AxisIndices
)(using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td2], s: ShapeOf[GatheredShape[S, AxisIndex, AxisIndices]], i: IndicesOf[AxisIndex], i2: IndicesOf[AxisIndices]): Tensor[T, Tuple3[Tt2,Td2,GatheredShape[S, AxisIndex, AxisIndices]]] = {
val map: Map[String, Any] = Map("axis" -> indicesOf[AxisIndex].indices.toArray.head)
val allInputs = Tuple2(data, Tensor(indicesOf[AxisIndices].indices.toArray, indicesOf[AxisIndices].indices.toArray.size.asInstanceOf[io.kjaer.compiletime.Dimension] #: SNil))
(callOp(name, "Gather", allInputs, map))
}
}
Expand Down
21 changes: 21 additions & 0 deletions core/src/main/scala/Tensors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,27 @@ object Tensors{
}
}

type GatheredShape[S <: Shape, AxisIndex <: None.type | Indices, AxisIndices <: Indices] <: Shape = AxisIndex match {
case None.type => SNil
case Indices => GatheredShapeLoop[S, AxisIndex, 0, AxisIndices]
}

protected type GatheredShapeLoop[ToGather <: Shape, AxisIndex <: Indices, I <: Index, AxisIndices <: Indices] <: Shape = ToGather match {
case head #: tail => Indices.Contains[AxisIndex, I] match {
case true => IndicesSize[AxisIndices] #: GatheredShapeLoop[tail, Indices.RemoveValue[AxisIndex, I], S[I], AxisIndices]
case false => head #: GatheredShapeLoop[tail, AxisIndex, S[I], AxisIndices]
}
case SNil => AxisIndex match {
case INil => SNil
}
}

type IndicesSize[AxisIndices <: Indices] = IndicesSizeLoop[AxisIndices, 0]

type IndicesSizeLoop[AxisIndices <: Indices, Acc <: Dimension] = AxisIndices match {
case head ::: tail => IndicesSizeLoop[tail, S[Acc]]
case INil => Acc
}

type FlattenedShape[S <: Shape, AxisIndex <: None.type | Indices] <: Shape = AxisIndex match {
case None.type => SNil
Expand Down

0 comments on commit 3ff58b3

Please sign in to comment.