Skip to content

Commit

Permalink
JBAI-5829 [examples] Added support for GemmVer9 and refactored existing…
Browse files Browse the repository at this point in the history
… Gemm logic.
  • Loading branch information
dmitriyb committed Sep 27, 2024
1 parent a38c555 commit 8fc7f7e
Showing 1 changed file with 84 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,65 @@ import io.kinference.protobuf.message.AttributeProto
import io.kinference.protobuf.message.TensorProto

/**
 * Base class for all opset versions of the ONNX Gemm operator:
 * Y = alpha * op(A) * op(B) + beta * C, where op(X) is X or X^T depending on transA/transB.
 *
 * NOTE(review): this region of the page is a unified diff with old and new lines interleaved;
 * the body below is the reconstructed post-commit version — confirm against the actual file.
 */
sealed class Gemm(name: String, info: OperatorInfo, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) : Operator<KITensor, KITensor>(name, info, attributes, inputs, outputs) {
    // Scalar multipliers for A*B and for the bias C (ONNX attributes, default 1.0).
    private val alpha: Double by attribute { it: Number -> it.toDouble() }
    private val beta: Double by attribute { it: Number -> it.toDouble() }

    // A non-zero attribute value means the corresponding input is transposed before the product.
    private val transA: Boolean by attribute { it: Number -> it.toInt() != 0 }
    private val transB: Boolean by attribute { it: Number -> it.toInt() != 0 }

    companion object {
        private val DEFAULT_VERSION = VersionInfo(sinceVersion = 11)

        /** Dispatches to the concrete implementation matching the requested opset version. */
        operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) = when (version ?: DEFAULT_VERSION.sinceVersion) {
            in GemmVer9.VERSION.asRange() -> GemmVer9(name, attributes, inputs, outputs)
            in GemmVer11.VERSION.asRange() -> GemmVer11(name, attributes, inputs, outputs)
            else -> error("Unsupported version of Gemm operator: $version")
        }
    }

    /**
     * Materializes the bias tensor into a mutable array of [targetShape], broadcasting
     * [array] along unit dimensions when its shape does not already match.
     */
    protected suspend fun getDest(array: NDArrayCore, type: DataType, targetShape: IntArray): MutableNDArrayCore {
        // Fast path: shapes already match — just take a mutable copy.
        if (array.shape.contentEquals(targetShape)) return array.toMutable()

        val dstArray = allocateNDArray(type, Strides(targetShape)) as MutableNumberNDArrayCore
        val unsqueezedShape = unsqueezeFirst(array.shape, targetShape.size)

        if (targetShape[1] != unsqueezedShape[1] && unsqueezedShape[1] == 1) {
            // Column broadcast: repeat each source element across a full row block.
            val targetBlockSize = targetShape[1]
            for (i in 0 until unsqueezedShape[0]) {
                val dstOffsetBase = i * targetBlockSize
                dstArray.fillByArrayValue(array, i, dstOffsetBase, dstOffsetBase + targetBlockSize)
            }
        } else {
            dstArray.copyFrom(0, array)
        }

        // Row broadcast: replicate the filled first row across the remaining rows.
        for (i in 1 until targetShape[0]) dstArray.copyFrom(i * targetShape[1], dstArray, 0, targetShape[1])
        return dstArray
    }

    /**
     * Shared Gemm implementation for all opset versions.
     *
     * @param optionalBias true when this opset allows input C to be absent (opset >= 11);
     *   a missing C is then replaced by a zero tensor of the output shape.
     * @return a single output tensor Y of shape [m, n].
     */
    protected suspend fun <D : ONNXData<*, *>> apply(inputs: List<KITensor?>, optionalBias: Boolean): List<KITensor?> {
        val a = inputs[0]!!.data as NumberNDArrayCore
        val b = inputs[1]!!.data as NumberNDArrayCore

        val m = if (!transA) a.shape[0] else a.shape[1]
        val n = if (!transB) b.shape[1] else b.shape[0]
        val k = if (!transA) a.shape[1] else a.shape[0]

        val targetShape = intArrayOf(m, n)
        val bias = if (optionalBias) {
            // Opset >= 11: C may be omitted; substitute zeros of the output shape.
            inputs.getOrNull(2)?.data ?: allocateNDArray(a.type, targetShape)
        } else {
            // Opset 9: C is required — fail with a clear message instead of a bare NPE /
            // IndexOutOfBoundsException when the input is missing.
            requireNotNull(inputs.getOrNull(2)) { "Gemm: required input C is missing" }.data
        } as NumberNDArrayCore

        // Reuse targetShape instead of rebuilding intArrayOf(m, n) a second time.
        val c = getDest(bias, a.type, targetShape)
        gemm(m, n, k, alpha, a, b, beta, c, transposeA = transA, transposeB = transB)

        return listOf(c.asTensor())
    }
}

/**
 * Gemm for opsets [9, 11): identical math to later versions, but the bias input C is required.
 * The computation itself is inherited from the [Gemm] base class.
 */
class GemmVer9(name: String, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) : Gemm(name, INFO, attributes, inputs, outputs) {
    companion object {
        // NOTE(review): the diff collapsed part of this companion ("Expand All" marker);
        // TYPE_CONSTRAINTS and ATTRIBUTES_INFO below are reconstructed from GemmVer11's
        // identical declarations — confirm against the actual file.
        private val TYPE_CONSTRAINTS = setOf(
            TensorProto.DataType.FLOAT16,
            TensorProto.DataType.FLOAT,
            TensorProto.DataType.DOUBLE,
            TensorProto.DataType.UINT32,
            TensorProto.DataType.UINT64,
            TensorProto.DataType.INT32,
            TensorProto.DataType.INT64,
            TensorProto.DataType.BFLOAT16
        )

        private val ATTRIBUTES_INFO = listOf(
            AttributeInfo("alpha", setOf(AttributeProto.AttributeType.FLOAT), false, 1.0),
            AttributeInfo("beta", setOf(AttributeProto.AttributeType.FLOAT), false, 1.0),
            AttributeInfo("transA", setOf(AttributeProto.AttributeType.INT), false, 0),
            AttributeInfo("transB", setOf(AttributeProto.AttributeType.INT), false, 0)
        )

        private val INPUTS_INFO = listOf(
            IOInfo(0, TYPE_CONSTRAINTS, "A", optional = false),
            IOInfo(1, TYPE_CONSTRAINTS, "B", optional = false),
            // In opset 9 the bias input C is required; it only became optional in opset 11.
            IOInfo(2, TYPE_CONSTRAINTS, "C", optional = false)
        )

        private val OUTPUTS_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "Y", optional = false))

        internal val VERSION = VersionInfo(sinceVersion = 9, untilVersion = 11)
        private val INFO = OperatorInfo("Gemm", ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN)
    }

    override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
        // Delegate to the shared base implementation; C is not optional in this opset.
        return apply<ONNXData<*, *>>(inputs, INPUTS_INFO[2].optional)
    }
}

/**
 * Gemm for opset >= 11: same math as earlier versions, with the bias input C optional
 * (a missing C behaves as a zero tensor). The computation is inherited from [Gemm].
 *
 * NOTE(review): this span of the diff contained removed old-version lines interleaved
 * with the new class body; the body below is the reconstructed post-commit version.
 */
class GemmVer11(name: String, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) : Gemm(name, INFO, attributes, inputs, outputs) {
    companion object {
        private val TYPE_CONSTRAINTS = setOf(
            TensorProto.DataType.FLOAT16,
            TensorProto.DataType.FLOAT,
            TensorProto.DataType.DOUBLE,
            TensorProto.DataType.UINT32,
            TensorProto.DataType.UINT64,
            TensorProto.DataType.INT32,
            TensorProto.DataType.INT64,
            TensorProto.DataType.BFLOAT16
        )

        private val ATTRIBUTES_INFO = listOf(
            AttributeInfo("alpha", setOf(AttributeProto.AttributeType.FLOAT), false, 1.0),
            AttributeInfo("beta", setOf(AttributeProto.AttributeType.FLOAT), false, 1.0),
            AttributeInfo("transA", setOf(AttributeProto.AttributeType.INT), false, 0),
            AttributeInfo("transB", setOf(AttributeProto.AttributeType.INT), false, 0)
        )

        private val INPUTS_INFO = listOf(
            IOInfo(0, TYPE_CONSTRAINTS, "A", optional = false),
            IOInfo(1, TYPE_CONSTRAINTS, "B", optional = false),
            // From opset 11 onward the bias input C is optional.
            IOInfo(2, TYPE_CONSTRAINTS, "C", optional = true)
        )

        private val OUTPUTS_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "Y", optional = false))

        internal val VERSION = VersionInfo(sinceVersion = 11)
        private val INFO = OperatorInfo("Gemm", ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN)
    }

    override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
        // Delegate to the shared base implementation; C may be absent in this opset.
        return apply<ONNXData<*, *>>(inputs, INPUTS_INFO[2].optional)
    }
}

0 comments on commit 8fc7f7e

Please sign in to comment.