Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: BaseModel abstract class, primitive Style Transfer, functions for retrieving model metadata #44

Merged
merged 27 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a2404f5
feat: implementation of model and style transfer model on android, ad…
NorbertKlockiewicz Dec 3, 2024
fefa0da
feat: not working implementation of Model abstract class for ios
jakmro Dec 3, 2024
6a0b6d0
feat: working style transfer module example for ios
NorbertKlockiewicz Dec 4, 2024
e396d45
feat: added methods to extract input type and shape from model(ios)
NorbertKlockiewicz Dec 4, 2024
d9e7668
feat: added functions for retrieving metadata about model with exampl…
NorbertKlockiewicz Dec 5, 2024
6a7d02d
feat: bindings support for models with multiple outputs(android)
NorbertKlockiewicz Dec 5, 2024
57536a4
rename Model to BaseModel
NorbertKlockiewicz Dec 5, 2024
741fbdc
fix: wrong import
NorbertKlockiewicz Dec 5, 2024
99cf9ff
fix: CI problems with example app, remove style transfer features fro…
NorbertKlockiewicz Dec 5, 2024
e7cb9d2
refactor: android code, BaseModel forward returns now every output
NorbertKlockiewicz Dec 5, 2024
3f41c24
fix: remove locally linked android lib
NorbertKlockiewicz Dec 6, 2024
9c7edef
fix: apply requested changes
NorbertKlockiewicz Dec 8, 2024
0be2992
fix: Change iOS native bindings so they return an array of outputs (#43)
chmjkb Dec 9, 2024
dc8ea07
feat: implementation of model and style transfer model on android, ad…
NorbertKlockiewicz Dec 3, 2024
bd8baea
feat: not working implementation of Model abstract class for ios
jakmro Dec 3, 2024
c3dfb6c
feat: working style transfer module example for ios
NorbertKlockiewicz Dec 4, 2024
07dc968
feat: added methods to extract input type and shape from model(ios)
NorbertKlockiewicz Dec 4, 2024
00141f2
feat: added functions for retrieving metadata about model with exampl…
NorbertKlockiewicz Dec 5, 2024
cda256a
feat: bindings support for models with multiple outputs(android)
NorbertKlockiewicz Dec 5, 2024
ede6105
rename Model to BaseModel
NorbertKlockiewicz Dec 5, 2024
2ccb419
fix: wrong import
NorbertKlockiewicz Dec 5, 2024
9109024
fix: CI problems with example app, remove style transfer features fro…
NorbertKlockiewicz Dec 5, 2024
3a39d05
refactor: android code, BaseModel forward returns now every output
NorbertKlockiewicz Dec 5, 2024
cf2a1a5
fix: remove locally linked android lib
NorbertKlockiewicz Dec 6, 2024
75681d7
fix: apply requested changes
NorbertKlockiewicz Dec 8, 2024
c6b8fe3
Merge branch '@norbertklockiewicz/style-transfer' of https://github.c…
NorbertKlockiewicz Dec 9, 2024
3444d18
Update ios/RnExecutorch/models/BaseModel.mm
NorbertKlockiewicz Dec 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 21 additions & 41 deletions android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt
Original file line number Diff line number Diff line change
@@ -1,57 +1,35 @@
package com.swmansion.rnexecutorch

import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReadableArray
import com.swmansion.rnexecutorch.utils.ArrayUtils
import com.swmansion.rnexecutorch.utils.ETError
import com.swmansion.rnexecutorch.utils.Fetcher
import com.swmansion.rnexecutorch.utils.ProgressResponseBody
import com.swmansion.rnexecutorch.utils.ResourceType
import com.swmansion.rnexecutorch.utils.TensorUtils
import okhttp3.OkHttpClient
import org.pytorch.executorch.Module
import org.pytorch.executorch.Tensor
import java.net.URL

class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(reactContext) {
private lateinit var module: Module
private val client = OkHttpClient()

override fun getName(): String {
return NAME
}

private fun downloadModel(
url: String, resourceType: ResourceType, callback: (path: String?, error: Exception?) -> Unit
) {
Fetcher.downloadResource(reactApplicationContext,
client,
url,
resourceType,
false,
{ path, error -> callback(path, error) },
object : ProgressResponseBody.ProgressListener {
override fun onProgress(bytesRead: Long, contentLength: Long, done: Boolean) {
}
})
}

override fun loadModule(modelPath: String, promise: Promise) {
try {
downloadModel(
modelPath, ResourceType.MODEL
) { path, error ->
if (error != null) {
promise.reject(error.message!!, "-1")
return@downloadModel
}

module = Module.load(path)
promise.resolve(0)
Fetcher.downloadModel(
reactApplicationContext,
modelPath,
) { path, error ->
if (error != null) {
promise.reject(error.message!!, ETError.InvalidModelPath.toString())
return@downloadModel
}
} catch (e: Exception) {
promise.reject(e.message!!, "-1")

module = Module.load(path)
promise.resolve(0)
return@downloadModel
}
}

Expand All @@ -75,19 +53,21 @@ class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(react
val executorchInput =
TensorUtils.getExecutorchInput(input, ArrayUtils.createLongArray(shape), inputType.toInt())

lateinit var result: Tensor
module.forward(executorchInput)[0].toTensor().also { result = it }
val result = module.forward(executorchInput)
val resultArray = Arguments.createArray()

for (evalue in result) {
resultArray.pushArray(ArrayUtils.createReadableArray(evalue.toTensor()))
}

promise.resolve(ArrayUtils.createReadableArray(result))
promise.resolve(resultArray)
return
} catch (e: IllegalArgumentException) {
//The error is thrown when transformation to Tensor fails
promise.reject("Forward Failed Execution", "18")
promise.reject("Forward Failed Execution", ETError.InvalidArgument.code.toString())
return
} catch (e: Exception) {
//Executorch forward method throws an exception with a message: "Method forward failed with code XX"
val exceptionCode = e.message!!.substring(e.message!!.length - 2)
promise.reject("Forward Failed Execution", exceptionCode)
promise.reject("Forward Failed Execution", e.message!!)
return
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,15 @@ import com.swmansion.rnexecutorch.utils.ResourceType
import com.swmansion.rnexecutorch.utils.llms.ChatRole
import com.swmansion.rnexecutorch.utils.llms.ConversationManager
import com.swmansion.rnexecutorch.utils.llms.END_OF_TEXT_TOKEN
import okhttp3.OkHttpClient
import org.pytorch.executorch.LlamaCallback
import org.pytorch.executorch.LlamaModule
import java.net.URL

class RnExecutorchModule(reactContext: ReactApplicationContext) :
NativeRnExecutorchSpec(reactContext), LlamaCallback {

private var llamaModule: LlamaModule? = null
private var tempLlamaResponse = StringBuilder()
private lateinit var conversationManager: ConversationManager
private val client = OkHttpClient()
private var isFetching = false

override fun getName(): String {
Expand Down Expand Up @@ -51,7 +48,7 @@ class RnExecutorchModule(reactContext: ReactApplicationContext) :
callback: (path: String?, error: Exception?) -> Unit,
) {
Fetcher.downloadResource(
reactApplicationContext, client, url, resourceType, isLargeFile,
reactApplicationContext, url, resourceType, isLargeFile,
{ path, error -> callback(path, error) },
object : ProgressResponseBody.ProgressListener {
override fun onProgress(bytesRead: Long, contentLength: Long, done: Boolean) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class RnExecutorchPackage : TurboReactPackage() {
RnExecutorchModule(reactContext)
} else if (name == ETModule.NAME) {
ETModule(reactContext)
} else if(name == StyleTransfer.NAME){
StyleTransfer(reactContext)
} else {
null
}
Expand All @@ -40,6 +42,15 @@ class RnExecutorchPackage : TurboReactPackage() {
false, // isCxxModule
true
)

moduleInfos[StyleTransfer.NAME] = ReactModuleInfo(
StyleTransfer.NAME,
StyleTransfer.NAME,
false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)
moduleInfos
}
}
Expand Down
49 changes: 49 additions & 0 deletions android/src/main/java/com/swmansion/rnexecutorch/StyleTransfer.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package com.swmansion.rnexecutorch

import android.graphics.BitmapFactory
import android.net.Uri
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.models.StyleTransferModel
import com.swmansion.rnexecutorch.utils.BitmapUtils
import com.swmansion.rnexecutorch.utils.ETError

class StyleTransfer(reactContext: ReactApplicationContext) :
NativeStyleTransferSpec(reactContext) {

private lateinit var styleTransferModel: StyleTransferModel

companion object {
const val NAME = "StyleTransfer"
}

override fun loadModule(modelSource: String, promise: Promise) {
try {
styleTransferModel = StyleTransferModel(reactApplicationContext)
styleTransferModel.loadModel(modelSource)
promise.resolve(0)
} catch (e: Exception) {
promise.reject(e.message!!, ETError.InvalidModelPath.toString())
}
}

override fun forward(input: String, promise: Promise) {
try {
val uri = Uri.parse(input)
val bitmapInputStream = reactApplicationContext.contentResolver.openInputStream(uri)
val rawBitmap = BitmapFactory.decodeStream(bitmapInputStream)
bitmapInputStream!!.close()

val output = styleTransferModel.runModel(rawBitmap)
val outputUri = BitmapUtils.saveToTempFile(output, "test")

promise.resolve(outputUri.toString())
}catch(e: Exception){
promise.reject(e.message!!, e.message)
}
}

override fun getName(): String {
return NAME
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package com.swmansion.rnexecutorch.models

import android.content.Context
import com.swmansion.rnexecutorch.utils.ETError
import com.swmansion.rnexecutorch.utils.Fetcher
import org.pytorch.executorch.EValue
import org.pytorch.executorch.Module


abstract class BaseModel<Input, Output>(val context: Context) {
protected lateinit var module: Module

fun loadModel(modelSource: String) {
Fetcher.downloadModel(
context,
modelSource
) { path, error ->
if (error != null) {
throw Error(error.message!!)
}

module = Module.load(path)
}
}

protected fun forward(input: EValue): Array<EValue> {
try {
val result = module.forward(input)
return result
} catch (e: IllegalArgumentException) {
//The error is thrown when transformation to Tensor fails
throw Error(ETError.InvalidArgument.code.toString())
} catch (e: Exception) {
throw Error(e.message!!)
}
}

abstract fun runModel(input: Input): Output

protected abstract fun preprocess(input: Input): Input

protected abstract fun postprocess(input: Output): Output
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.swmansion.rnexecutorch.models

import android.graphics.Bitmap
import android.util.Log
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.utils.TensorUtils
import org.pytorch.executorch.EValue

class StyleTransferModel(reactApplicationContext: ReactApplicationContext) : BaseModel<Bitmap, Bitmap>(reactApplicationContext) {
override fun runModel(input: Bitmap): Bitmap {
val processedData = preprocess(input)
val inputTensor = TensorUtils.bitmapToFloat32Tensor(processedData)

Log.d("RnExecutorch", module.numberOfInputs.toString())
for (i in 0 until module.numberOfInputs) {
Log.d("RnExecutorch", module.getInputType(i).toString())
for(shape in module.getInputShape(i)){
Log.d("RnExecutorch", shape.toString())
}
}

Log.d("RnExecutorch", module.numberOfOutputs.toString())
for(i in 0 until module.numberOfOutputs){
Log.d("RnExecutorch", module.getOutputType(i).toString())
for(shape in module.getOutputShape(i)){
Log.d("RnExecutorch", shape.toString())
}
}

val outputTensor = forward(EValue.from(inputTensor))
val outputData = postprocess(TensorUtils.float32TensorToBitmap(outputTensor[0].toTensor()))

return outputData
}

override fun preprocess(input: Bitmap): Bitmap {
val inputBitmap = Bitmap.createScaledBitmap(
input,
640, 640, true
)
return inputBitmap
}

override fun postprocess(input: Bitmap): Bitmap {
val scaledUpBitmap = Bitmap.createScaledBitmap(
input,
1280, 1280, true
)
return scaledUpBitmap
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package com.swmansion.rnexecutorch.utils

import android.graphics.Bitmap
import android.graphics.Matrix
import android.net.Uri
import androidx.core.net.toUri
import java.io.File
import java.io.FileOutputStream
import java.io.IOException

class BitmapUtils {
companion object {
fun saveToTempFile(bitmap: Bitmap, fileName: String): Uri {
val tempFile = File.createTempFile(fileName, ".png")
var outputStream : FileOutputStream? = null
try {
outputStream = FileOutputStream(tempFile)
bitmap.compress(Bitmap.CompressFormat.PNG, 100, outputStream)
} catch (e: IOException) {
e.printStackTrace()
}
finally {
outputStream?.close()
}
return tempFile.toUri()
}

private fun rotateBitmap(bitmap: Bitmap, angle: Float): Bitmap {
val matrix = Matrix()
matrix.postRotate(angle)
return Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
}

private fun flipBitmap(bitmap: Bitmap, horizontal: Boolean, vertical: Boolean): Bitmap {
val matrix = Matrix()
matrix.preScale(
if (horizontal) -1f else 1f,
if (vertical) -1f else 1f
)
return Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
}
}

}
29 changes: 29 additions & 0 deletions android/src/main/java/com/swmansion/rnexecutorch/utils/ETError.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package com.swmansion.rnexecutorch.utils

enum class ETError(val code: Int) {
InvalidModelPath(0xff),

// System errors
Ok(0x00),
Internal(0x01),
InvalidState(0x02),
EndOfMethod(0x03),

// Logical errors
NotSupported(0x10),
NotImplemented(0x11),
InvalidArgument(0x12),
InvalidType(0x13),
OperatorMissing(0x14),

// Resource errors
NotFound(0x20),
MemoryAllocationFailed(0x21),
AccessFailed(0x22),
InvalidProgram(0x23),

// Delegate errors
DelegateInvalidCompatibility(0x30),
DelegateMemoryAllocationFailed(0x31),
DelegateInvalidHandle(0x32);
}
Loading
Loading