Skip to content

Commit

Permalink
JBAI-197 [ndarray] Optimized broadcasting functions by avoiding some allocations.
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitriyb committed Oct 15, 2024
1 parent cca1f0f commit 80e3106
Showing 1 changed file with 30 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import io.kinference.primitives.types.DataType
// TODO move to a different module
fun unsqueezeFirst(shape: IntArray, newShapeSize: Int): IntArray {
val wrapSize = newShapeSize - shape.size
if (wrapSize == 0) return shape

val wrappedShape = IntArray(newShapeSize)
wrappedShape.fill(1, 0, wrapSize)
Expand Down Expand Up @@ -35,7 +36,11 @@ object Broadcasting {
destination: MutableNDArrayCore,
op: suspend (List<NDArrayCore>, MutableNDArrayCore) -> Unit
): MutableNDArrayCore {
val wrappedInputs = inputs.map { it.reshape(unsqueezeFirst(it.shape, destination.shape.size)) }
val destRank = destination.shape.size
val wrappedInputs = inputs.map { input ->
if (input.shape.size == destRank) input
else input.reshape(unsqueezeFirst(input.shape, destRank))
}

broadcast(wrappedInputs, destination, op)
return destination
Expand Down Expand Up @@ -101,29 +106,43 @@ object Broadcasting {
destination: MutableNDArrayCore,
recurrentBack: suspend (List<NDArrayCore>, MutableNDArrayCore) -> Unit
) {
val numInputs = inputs.size
val indexedInputs = inputs.withIndex()
val (arraysWithOne, arraysWithoutOne) = indexedInputs.partition { it.value.shape[0] == 1 }

val mergedInputs = MutableList<NDArrayCore>(numInputs) { inputs[0] }

if (destination.shape.size == 1) {
val broadcastSize = destination.shape.last()
val broadcastArraysWithOne = arraysWithOne.map {
val value = allocateNDArray(it.value.type, Strides(intArrayOf(broadcastSize)))
it.copy(value = value.apply { fill(it.value.singleValue()) })
val broadcastArraysWithOne = arraysWithOne.map { indexedInput ->
val value = allocateNDArray(indexedInput.value.type, Strides(intArrayOf(broadcastSize)))
value.apply { fill(indexedInput.value.singleValue()) }
}

arraysWithOne.forEachIndexed { i, indexedInput ->
mergedInputs[indexedInput.index] = broadcastArraysWithOne[i]
}

for (indexedInput in arraysWithoutOne) {
mergedInputs[indexedInput.index] = indexedInput.value
}
val mergedInputs = broadcastArraysWithOne.plus(arraysWithoutOne).sortedBy { it.index }.map { it.value }

return recurrentBack(mergedInputs, destination)
}

val viewedArraysWithOne = arraysWithOne.map { it.copy(value = it.value.view(0)) }
val fixedViewsWithOne = arraysWithOne.map { it.copy(value = it.value.view(0)) }

for (i in 0 until destination.shape[0]) {
val viewedArraysWithoutOne = arraysWithoutOne.map { it.copy(value = it.value.view(i)) }
val viewedDestination = destination.viewMutable(i)
for (indexedInput in fixedViewsWithOne) {
mergedInputs[indexedInput.index] = indexedInput.value
}

val mergedViewedInputs = viewedArraysWithOne.plus(viewedArraysWithoutOne).sortedBy { it.index }.map { it.value }
for (i in 0 until destination.shape[0]) {
for (indexedInput in arraysWithoutOne) {
mergedInputs[indexedInput.index] = indexedInput.value.view(i)
}

recurrentBack(mergedViewedInputs, viewedDestination)
val viewedDestination = destination.viewMutable(i)
recurrentBack(mergedInputs, viewedDestination)
}
}
}

0 comments on commit 80e3106

Please sign in to comment.