Skip to content

Commit

Permalink
JBAI-4393 [core, ndarray] Refactored memory management and array hand…
Browse files Browse the repository at this point in the history
…ling: streamlined array type handling and improved memory limit checks within create and reset methods; KIModel predict improved for NoAllocator case.
  • Loading branch information
dmitriyb committed Aug 15, 2024
1 parent 2d7c310 commit 9caf75c
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 112 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ class KIModel(
val name: String,
val opSet: OperatorSetRegistry,
val graph: KIGraph,
memoryLimiter: MemoryLimiter = MemoryLimiters.NoAllocator,
private val memoryLimiter: MemoryLimiter = MemoryLimiters.NoAllocator,
parallelismLimit: Int = PlatformUtils.cores,
) : Model<KIONNXData<*>>, Profilable, Cacheable {
private val profiles: MutableList<ProfilingContext> = ArrayList()

@OptIn(ExperimentalCoroutinesApi::class)
private val dispatcher: CoroutineDispatcher = Dispatchers.Default.limitedParallelism(parallelismLimit)
private val modelArrayStorage: ModelArrayStorage = ModelArrayStorage(memoryLimiter)
private val modelArrayStorage: ModelArrayStorage = ModelArrayStorage(MemoryLimiters.Default)

override fun addProfilingContext(name: String): ProfilingContext = ProfilingContext(name).apply { profiles.add(this) }
override fun analyzeProfilingResults(): ProfileAnalysisEntry = profiles.analyze("Model $name")
Expand All @@ -44,15 +44,21 @@ class KIModel(
coreReserved = true
}

val allocatorContext = modelArrayStorage.createAllocatorContext()
val mixedContext = allocatorContext + limiterContext
if (memoryLimiter == MemoryLimiters.NoAllocator) {
withContext(limiterContext) {
return@withContext graph.execute(input, contexts)
}
} else {
val allocatorContext = modelArrayStorage.createAllocatorContext()
val mixedContext = allocatorContext + limiterContext

withContext(mixedContext) {
val coroutineContext = coroutineContext[AllocatorContext.Key]!!
val execResult = graph.execute(input, contexts)
val copies = execResult.map { it.clone(it.name) }.toList()
coroutineContext.closeAllocated()
copies
withContext(mixedContext) {
val coroutineContext = coroutineContext[AllocatorContext.Key]!!
val execResult = graph.execute(input, contexts)
val copies = execResult.map { it.clone(it.name) }.toList()
coroutineContext.closeAllocated()
return@withContext copies
}
}
} finally {
if (coreReserved) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
package io.kinference.ndarray.arrays

enum class ArrayTypes(val index: Int, val size: Int) {
ByteArray(0, Byte.SIZE_BYTES),
UByteArray(1, UByte.SIZE_BYTES),
ShortArray(2, Short.SIZE_BYTES),
UShortArray(3, UShort.SIZE_BYTES),
IntArray(4, Int.SIZE_BYTES),
UIntArray(5, UInt.SIZE_BYTES),
LongArray(6, Long.SIZE_BYTES),
ULongArray(7, ULong.SIZE_BYTES),
FloatArray(8, Float.SIZE_BYTES),
DoubleArray(9, Double.SIZE_BYTES),
BooleanArray(10, 1);
ByteArrayType(0, Byte.SIZE_BYTES),
UByteArrayType(1, UByte.SIZE_BYTES),
ShortArrayType(2, Short.SIZE_BYTES),
UShortArrayType(3, UShort.SIZE_BYTES),
IntArrayType(4, Int.SIZE_BYTES),
UIntArrayType(5, UInt.SIZE_BYTES),
LongArrayType(6, Long.SIZE_BYTES),
ULongArrayType(7, ULong.SIZE_BYTES),
FloatArrayType(8, Float.SIZE_BYTES),
DoubleArrayType(9, Double.SIZE_BYTES),
BooleanArrayType(10, 1);

companion object {
fun sizeInBytes(index: Int, arraySize: Int): Long = entries[index].size * arraySize.toLong()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,7 @@ data class AllocatorContext internal constructor(
override val key: CoroutineContext.Key<*> get() = Key

internal fun getArrayContainers(type: ArrayTypes, size: Int, count: Int): Array<Any> {
return if (limiter !is NoAllocatorMemoryLimiter) {
Array(count) { unusedContainers.getArrayContainer(type, size) }
} else {
Array(count) { unusedContainers.create(type, size) }
}
}

fun closeOperator() {
unusedContainers.moveUsedArrays()
return Array(count) { unusedContainers.getArrayContainer(type, size) }
}

fun closeAllocated() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import io.kinference.ndarray.arrays.ArrayTypes

internal class ArrayStorage(typeLength: Int, sizeLength: Int, private val limiter: MemoryLimiter) {
/**
* This is a storage for arrays which are available for retrieving
*
* Structure is as follows:
* 1. Array by predefined types (all types are known compiled time)
* 2. Array by size. Starting with 'INIT_SIZE_VALUE' element and grow it doubling (typically there are no more than 16 different sizes)
Expand All @@ -12,30 +14,52 @@ internal class ArrayStorage(typeLength: Int, sizeLength: Int, private val limite
private var storageUnused: Array<Array<ArrayDeque<Any>>> =
Array(typeLength) { Array(sizeLength) { ArrayDeque() } }

/**
* This is a storage for arrays which are currently in use.
* They should be moved back into unused storage when there is no need for them.
*
* Structure is as follows:
* 1. Array by predefined types (all types are known compiled time)
* 2. Array by size.
* Starting with 'INIT_SIZE_VALUE' element and grow it doubling (typically there are no more than 16 different sizes)
* 3. Queue of array containers (used as FIFO)
*/
private var storageUsed: Array<Array<ArrayDeque<Any>>> =
Array(typeLength) { Array(sizeLength) { ArrayDeque() } }

private var sizeIndices: IntArray = IntArray(typeLength)
private var sizes: Array<IntArray> = Array(typeLength) { IntArray(sizeLength) }

internal fun getArrayContainer(type: ArrayTypes, size: Int): Any {
return if (limiter.checkMemoryLimitAndAdd(ArrayTypes.sizeInBytes(type.index, size))) {
val tIndex = type.index
val sIndex = getSizeIndex(tIndex, size)
val array = storageUnused[tIndex][sIndex].removeFirstOrNull()?.also(::resetArray)
?: create(type, size)

operator fun get(typeIndex: Int, sizeIndex: Int): ArrayDeque<Any> {
return storageUnused[typeIndex][sizeIndex]
storageUsed[tIndex][sIndex].addLast(array)
array
} else {
create(type, size)
}
}

fun getArrayContainer(type: ArrayTypes, size: Int): Any {
val tIndex = type.index
internal fun moveUsedArrays() {
storageUsed.forEachIndexed { typeIndex, arraysByType ->
arraysByType.forEachIndexed { sizeIndex, arrayDeque ->
arrayDeque.forEach {
storageUnused[typeIndex][sizeIndex].addLast(it)
}
arrayDeque.clear()
}
}
limiter.resetLimit()
}

private fun getSizeIndex(tIndex: Int, size: Int): Int {
val sIndex = sizes[tIndex].indexOf(size)

// Checking that we have this array size in our storage for this type
val idx = if (sIndex != -1) {
val array = storageUnused[tIndex][sIndex].removeFirstOrNull()
array?.let {
resetArray(it)
limiter.deductMemory((type.size * size).toLong())
storageUsed[tIndex][sIndex].addLast(it)
return it
}
return if (sIndex != -1) {
sIndex
} else {
if (sizeIndices[tIndex] >= storageUnused[tIndex].size)
Expand All @@ -45,22 +69,6 @@ internal class ArrayStorage(typeLength: Int, sizeLength: Int, private val limite
sizes[tIndex][idx] = size
idx
}

val array = create(type, size)
storageUsed[tIndex][idx].addLast(array)

return array
}

fun moveUsedArrays() {
storageUsed.forEachIndexed { typeIndex, arraysByType ->
arraysByType.forEachIndexed { sizeIndex, arrayDeque ->
arrayDeque.forEach {
storageUnused[typeIndex][sizeIndex].addLast(it)
}
arrayDeque.clear()
}
}
}

private fun grow(typeIndex: Int) {
Expand All @@ -78,37 +86,36 @@ internal class ArrayStorage(typeLength: Int, sizeLength: Int, private val limite
sizes[typeIndex] = sizes[typeIndex].copyOf(newSize)
}

fun create(type: ArrayTypes, size: Int): Any {
private fun create(type: ArrayTypes, size: Int): Any {
return when (type) {
ArrayTypes.ByteArray -> ByteArray(size) // 8-bit signed
ArrayTypes.UByteArray -> UByteArray(size) // 8-bit unsigned
ArrayTypes.ShortArray -> ShortArray(size) // 16-bit signed
ArrayTypes.UShortArray -> UShortArray(size) // 16-bit unsigned
ArrayTypes.IntArray -> IntArray(size) // 32-bit signed
ArrayTypes.UIntArray -> UIntArray(size) // 32-bit unsigned
ArrayTypes.LongArray -> LongArray(size) // 64-bit signed
ArrayTypes.ULongArray -> ULongArray(size) // 64-bit unsigned
ArrayTypes.FloatArray -> FloatArray(size)
ArrayTypes.DoubleArray -> DoubleArray(size)
ArrayTypes.BooleanArray -> BooleanArray(size)
ArrayTypes.ByteArrayType -> ByteArray(size) // 8-bit signed
ArrayTypes.UByteArrayType -> UByteArray(size) // 8-bit unsigned
ArrayTypes.ShortArrayType -> ShortArray(size) // 16-bit signed
ArrayTypes.UShortArrayType -> UShortArray(size) // 16-bit unsigned
ArrayTypes.IntArrayType -> IntArray(size) // 32-bit signed
ArrayTypes.UIntArrayType -> UIntArray(size) // 32-bit unsigned
ArrayTypes.LongArrayType -> LongArray(size) // 64-bit signed
ArrayTypes.ULongArrayType -> ULongArray(size) // 64-bit unsigned
ArrayTypes.FloatArrayType -> FloatArray(size)
ArrayTypes.DoubleArrayType -> DoubleArray(size)
ArrayTypes.BooleanArrayType -> BooleanArray(size)
else -> throw IllegalArgumentException("Unsupported array type")
}
}

private fun resetArray(array: Any) {
private fun resetArray(array: Any): Unit =
when (array) {
is ByteArray -> array.fill(0) // 8-bit signed
is UByteArray -> array.fill(0u) // 8-bit unsigned
is ShortArray -> array.fill(0) // 16-bit signed
is UShortArray -> array.fill(0u) // 16-bit unsigned
is IntArray -> array.fill(0) // 32-bit signed
is UIntArray -> array.fill(0u) // 32-bit unsigned
is LongArray -> array.fill(0L) // 64-bit signed
is ULongArray -> array.fill(0U) // 64-bit unsigned
is ByteArray -> array.fill(0) // 8-bit signed
is UByteArray -> array.fill(0u) // 8-bit unsigned
is ShortArray -> array.fill(0) // 16-bit signed
is UShortArray -> array.fill(0u) // 16-bit unsigned
is IntArray -> array.fill(0) // 32-bit signed
is UIntArray -> array.fill(0u) // 32-bit unsigned
is LongArray -> array.fill(0L) // 64-bit signed
is ULongArray -> array.fill(0U) // 64-bit unsigned
is FloatArray -> array.fill(0.0f)
is DoubleArray -> array.fill(0.0)
is BooleanArray -> array.fill(false)
else -> throw IllegalArgumentException("Unsupported array type")
}
}
}
Original file line number Diff line number Diff line change
@@ -1,57 +1,45 @@
package io.kinference.ndarray.arrays.memory

import io.kinference.utils.PlatformUtils
import kotlinx.atomicfu.AtomicLong
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.*

interface MemoryLimiter {
/**
* Checks if the memory limit allows adding the specified amount of memory and performs the addition.
* Checks if the memory limit allows adding the specified amount of memory and performs the addition
*
* @param added the memory in bytes to add
* @return true if the memory was added successfully and false if adding the memory exceeds the memory limit
*/
fun checkMemoryLimitAndAdd(added: Long): Boolean

/**
* Deducts the specified amount of memory from the memory limiter.
*
* @param deducted the memory in bytes to deduct from the memory limiter
* Resets the used memory into 0L
*/
fun deductMemory(deducted: Long)
fun resetLimit()
}

class BaseMemoryLimiter(private val memoryLimit: Long) : MemoryLimiter {
class BaseMemoryLimiter internal constructor(private val memoryLimit: Long) : MemoryLimiter {
private var usedMemory: AtomicLong = atomic(0L)

override fun checkMemoryLimitAndAdd(added: Long): Boolean {
val currentMemory = usedMemory.addAndGet(added)
return if (currentMemory > memoryLimit) {
usedMemory.addAndGet(-added)
false
} else true
// Attempt to add memory and check the limit
val successful = usedMemory.getAndUpdate { current ->
if (current + added > memoryLimit) current else current + added
} != usedMemory.value // Check if the update was successful

return successful
}

override fun deductMemory(deducted: Long) {
usedMemory.addAndGet(-deducted)
override fun resetLimit() {
usedMemory.value = 0L
}
}

object MemoryLimiters {
val Default: MemoryLimiter = BaseMemoryLimiter((PlatformUtils.maxHeap * 0.3).toLong())
val NoAllocator: MemoryLimiter = NoAllocatorMemoryLimiter
val NoAllocator: MemoryLimiter = BaseMemoryLimiter(0L)

fun customLimiter(memoryLimit: Long): MemoryLimiter {
return BaseMemoryLimiter(memoryLimit)
}
}

internal object NoAllocatorMemoryLimiter : MemoryLimiter {
override fun checkMemoryLimitAndAdd(added: Long): Boolean {
return false
}

override fun deductMemory(deducted: Long) {

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ internal class PrimitiveTiledArray(val blocks: Array<PrimitiveArray>) {
}

companion object {
val type: ArrayTypes = ArrayTypes.valueOf(PrimitiveArray::class.simpleName!!)
val type: ArrayTypes = ArrayTypes.valueOf(PrimitiveArray::class.simpleName!! + "Type")

suspend operator fun invoke(strides: Strides): PrimitiveTiledArray {
val blockSize = blockSizeByStrides(strides)
Expand Down Expand Up @@ -127,19 +127,16 @@ internal class PrimitiveTiledArray(val blocks: Array<PrimitiveArray>) {
blocks[blockIdx][blockOff] = value
}

suspend fun copyOf(): PrimitiveTiledArray {
// val copyArray = PrimitiveTiledArray(size, blockSize)
fun copyOf(): PrimitiveTiledArray {
val copyBlocks = Array(blocksNum) { PrimitiveArray(blockSize) }

for (blockNum in 0 until blocksNum) {
val thisBlock = this.blocks[blockNum]
// val destBlock = copyArray.blocks[blockNum]
val destBlock = copyBlocks[blockNum]

thisBlock.copyInto(destBlock)
}

// return copyArray
return PrimitiveTiledArray(copyBlocks)
}

Expand Down

0 comments on commit 9caf75c

Please sign in to comment.