Skip to content

Commit

Permalink
JBAI-4393 [core, ndarray] Added getPrimitiveBlock extension functions…
Browse files Browse the repository at this point in the history
… for better primitive types handling: this solution gives less double primitive array allocations when Array<Any> changes to actual type.
  • Loading branch information
dmitriyb committed Aug 29, 2024
1 parent 954f6cc commit b83f7f8
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package io.kinference.ndarray.arrays.memory.contexts

import io.kinference.ndarray.arrays.memory.storage.AutoArrayHandlingStorage
import io.kinference.primitives.types.DataType
import io.kinference.primitives.types.PrimitiveArray
import kotlin.coroutines.*

internal class AutoAllocatorContext internal constructor(
Expand All @@ -10,8 +11,4 @@ internal class AutoAllocatorContext internal constructor(

companion object Key : CoroutineContext.Key<AutoAllocatorContext>
override val key: CoroutineContext.Key<*> get() = Key

internal fun getArrays(type: DataType, size: Int, count: Int): Array<Any> {
return storage.getArrays(type, size, count)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import kotlin.coroutines.CoroutineContext

interface BaseAllocatorContext: CoroutineContext.Element

abstract class BaseAllocatorContextWithStorage<T : ArrayStorage>(protected val storage: T) : BaseAllocatorContext {
abstract class BaseAllocatorContextWithStorage<T : ArrayStorage>(internal val storage: T) : BaseAllocatorContext {
fun finalizeContext() {
storage.resetState()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
package io.kinference.ndarray.arrays.memory.storage

import io.kinference.ndarray.arrays.memory.*
import io.kinference.primitives.types.DataType

internal interface TypedAutoHandlingStorage {
fun getBlock(blocksNum: Int, blockSize: Int, limiter: MemoryManager): Array<Any>
fun moveBlocksIntoUnused()
}

internal class AutoArrayHandlingStorage(private val limiter: MemoryManager) : ArrayStorage {
private val storage: List<TypedAutoHandlingStorage> = listOf(
internal class AutoArrayHandlingStorage(internal val limiter: MemoryManager) : ArrayStorage {
internal val storage: List<TypedAutoHandlingStorage> = listOf(
ByteAutoHandlingArrayStorage(),
ShortAutoHandlingArrayStorage(),
IntAutoHandlingArrayStorage(),
Expand All @@ -23,10 +21,6 @@ internal class AutoArrayHandlingStorage(private val limiter: MemoryManager) : Ar
BooleanAutoHandlingArrayStorage()
)

internal fun getArrays(type: DataType, size: Int, count: Int): Array<Any> {
return storage[type.ordinal].getBlock(blocksNum = count, blockSize = size, limiter = limiter)
}

override fun resetState() {
storage.forEach { it.moveBlocksIntoUnused() }
limiter.resetLimit()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ internal class PrimitiveAutoHandlingArrayStorage : TypedAutoHandlingStorage {
private val type = DataType.CurrentPrimitive
}

override fun getBlock(blocksNum: Int, blockSize: Int, limiter: MemoryManager): Array<Any> {
fun getBlock(blocksNum: Int, blockSize: Int, limiter: MemoryManager): Array<PrimitiveArray> {
val unusedQueue = unused.getOrPut(blockSize) { ArrayDeque(blocksNum) }
val usedQueue = used.getOrPut(blockSize) { ArrayDeque(blocksNum) }

val blocks = if (limiter.checkMemoryLimitAndAdd(type, blockSize * blocksNum)) {
Array(blocksNum) {
val block = unusedQueue.removeFirstOrNull()
block?.fill(PrimitiveConstants.ZERO)
block ?: PrimitiveArray(blockSize)
unusedQueue.removeFirstOrNull()?.apply {
fill(PrimitiveConstants.ZERO)
} ?: PrimitiveArray(blockSize)
}
} else {
Array(blocksNum) {
Expand All @@ -35,7 +35,7 @@ internal class PrimitiveAutoHandlingArrayStorage : TypedAutoHandlingStorage {

usedQueue.addAll(blocks)

return blocks as Array<Any>
return blocks
}

override fun moveBlocksIntoUnused() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
@file:GeneratePrimitives(DataType.ALL)
@file:Suppress("DuplicatedCode")
package io.kinference.ndarray.arrays.memory.storage

import io.kinference.ndarray.arrays.memory.contexts.AutoAllocatorContext
import io.kinference.primitives.annotations.GenerateNameFromPrimitives
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.*

@GenerateNameFromPrimitives
internal fun AutoArrayHandlingStorage.getPrimitiveBlock(blocksNum: Int, blockSize: Int): Array<PrimitiveArray> {
return (storage[DataType.CurrentPrimitive.ordinal] as PrimitiveAutoHandlingArrayStorage).getBlock(blocksNum = blocksNum, blockSize = blockSize, limiter = limiter)
}

@GenerateNameFromPrimitives
internal fun AutoAllocatorContext.getPrimitiveBlock(blocksNum: Int, blockSize: Int): Array<PrimitiveArray> {
return storage.getPrimitiveBlock(blocksNum = blocksNum, blockSize = blockSize)
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
package io.kinference.ndarray.arrays.tiled

import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.arrays.memory.*
import io.kinference.ndarray.arrays.memory.contexts.AutoAllocatorContext
import io.kinference.ndarray.arrays.memory.storage.*
import io.kinference.ndarray.arrays.pointers.PrimitivePointer
import io.kinference.ndarray.arrays.pointers.accept
import io.kinference.ndarray.blockSizeByStrides
Expand Down Expand Up @@ -59,11 +59,9 @@ internal class PrimitiveTiledArray(val blocks: Array<PrimitiveArray>) {
require(size % blockSize == 0) { "Size must divide blockSize" }

val blocksNum = if (blockSize == 0) 0 else size / blockSize
val blocks = coroutineContext[AutoAllocatorContext.Key]?.getPrimitiveBlock(blocksNum, blockSize) ?: Array(blocksNum) { PrimitiveArray(blockSize) }

val coroutineContext = coroutineContext[AutoAllocatorContext.Key]
val blocks = coroutineContext?.getArrays(type, blockSize, blocksNum) ?: Array(blocksNum) { PrimitiveArray(blockSize) }

return PrimitiveTiledArray(blocks.map { it as PrimitiveArray }.toTypedArray())
return PrimitiveTiledArray(blocks)
}

suspend operator fun invoke(size: Int, blockSize: Int, init: (InlineInt) -> PrimitiveType) : PrimitiveTiledArray {
Expand Down

0 comments on commit b83f7f8

Please sign in to comment.