Skip to content

Commit

Permalink
JBAI-197 [ndarray] Optimized FastGelu operation by reusing temporary …
Browse files Browse the repository at this point in the history
…blocks.
  • Loading branch information
dmitriyb committed Oct 15, 2024
1 parent d3bf6db commit cca1f0f
Showing 1 changed file with 9 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ package io.kinference.ndarray.extensions.gelu
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.arrays.MutablePrimitiveNDArray
import io.kinference.ndarray.arrays.PrimitiveNDArray
import io.kinference.ndarray.arrays.memory.contexts.AutoAllocatorContext
import io.kinference.ndarray.arrays.memory.storage.*
import io.kinference.ndarray.arrays.tiled.PrimitiveTiledArray
import io.kinference.ndarray.countCoroutinesByData
import io.kinference.ndarray.parallelizeByBlocks
Expand All @@ -16,6 +18,7 @@ import io.kinference.ndarray.math.FastMath
import io.kinference.ndarray.math.exp
import io.kinference.primitives.annotations.GenerateNameFromPrimitives
import io.kinference.primitives.annotations.GeneratePrimitives
import kotlin.coroutines.coroutineContext
import kotlin.math.*

@GenerateNameFromPrimitives
Expand All @@ -27,11 +30,15 @@ internal suspend fun fastGeluPrimitive(input: PrimitiveNDArray, bias: PrimitiveN

val blockSize = input.array.blockSize

val coroutineCount = countCoroutinesByData(blockSize, inputBlocks.size, 2048)
val temporaryBlocksExp = coroutineContext[AutoAllocatorContext]?.getPrimitiveBlock(coroutineCount, blockSize)
?: Array(coroutineCount) { PrimitiveArray(blockSize) }

// The constant 2048 was precomputed on an M1 Max processor.
// With this constant, two launches run faster than a single thread without launches.
// TODO: (cupertank) Remove constants
parallelizeByBlocks(blockSize, inputBlocks.size, 2048) { blockStart, blockEnd, _ ->
val temporaryBlockExp = PrimitiveArray(blockSize)
parallelizeByBlocks(blockSize, inputBlocks.size, 2048) { blockStart, blockEnd, coroutineIndex ->
val temporaryBlockExp = temporaryBlocksExp[coroutineIndex]
for (blockIdx in blockStart until blockEnd) {
val outputBlock = outputBlocks[blockIdx]
val block = inputBlocks[blockIdx]
Expand Down

0 comments on commit cca1f0f

Please sign in to comment.