Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

@jakmro/classification android #55

Merged
merged 3 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions android/src/main/java/com/swmansion/rnexecutorch/Classification.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.swmansion.rnexecutorch

import android.util.Log
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.models.classification.ClassificationModel
import com.swmansion.rnexecutorch.utils.ETError
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.android.OpenCVLoader
import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.WritableMap

class Classification(reactContext: ReactApplicationContext) :
NativeClassificationSpec(reactContext) {

private lateinit var classificationModel: ClassificationModel

companion object {
const val NAME = "Classification"
init {
if(!OpenCVLoader.initLocal()){
Log.d("rn_executorch", "OpenCV not loaded")
} else {
Log.d("rn_executorch", "OpenCV loaded")
}
}
}

override fun loadModule(modelSource: String, promise: Promise) {
try {
classificationModel = ClassificationModel(reactApplicationContext)
classificationModel.loadModel(modelSource)
promise.resolve(0)
} catch (e: Exception) {
promise.reject(e.message!!, ETError.InvalidModelPath.toString())
}
}

override fun forward(input: String, promise: Promise) {
try {
val image = ImageProcessor.readImage(input)
val output = classificationModel.runModel(image)

val writableMap: WritableMap = Arguments.createMap()

for ((key, value) in output) {
writableMap.putDouble(key, value.toDouble())
}

promise.resolve(writableMap)
}catch(e: Exception){
promise.reject(e.message!!, e.message)
}
}

override fun getName(): String {
return NAME
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ class RnExecutorchPackage : TurboReactPackage() {
ETModule(reactContext)
} else if (name == StyleTransfer.NAME) {
StyleTransfer(reactContext)
} else {
} else if (name == Classification.NAME) {
Classification(reactContext)
}
else {
null
}

Expand Down Expand Up @@ -51,6 +54,15 @@ class RnExecutorchPackage : TurboReactPackage() {
false, // isCxxModule
true
)

moduleInfos[Classification.NAME] = ReactModuleInfo(
Classification.NAME,
Classification.NAME,
false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)
moduleInfos
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ abstract class BaseModel<Input, Output>(val context: Context) {

abstract fun runModel(input: Input): Output

protected abstract fun preprocess(input: Input): Input
protected abstract fun preprocess(input: Input): EValue

protected abstract fun postprocess(input: Tensor): Output
protected abstract fun postprocess(output: Array<EValue>): Output
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.opencv.core.Mat
import org.opencv.core.Size
import org.opencv.imgproc.Imgproc
import org.pytorch.executorch.Tensor
import org.pytorch.executorch.EValue


class StyleTransferModel(reactApplicationContext: ReactApplicationContext) : BaseModel<Mat, Mat>(reactApplicationContext) {
Expand All @@ -19,22 +20,23 @@ class StyleTransferModel(reactApplicationContext: ReactApplicationContext) : Bas
return Size(height.toDouble(), width.toDouble())
}

override fun preprocess(input: Mat): Mat {
override fun preprocess(input: Mat): EValue {
originalSize = input.size()
Imgproc.resize(input, input, getModelImageSize())
return input
return ImageProcessor.matToEValue(input, module.getInputShape(0))
}

override fun postprocess(input: Tensor): Mat {
override fun postprocess(output: Array<EValue>): Mat {
val tensor = output[0].toTensor()
val modelShape = getModelImageSize()
val result = ImageProcessor.EValueToMat(input.dataAsFloatArray, modelShape.width.toInt(), modelShape.height.toInt())
val result = ImageProcessor.EValueToMat(tensor.dataAsFloatArray, modelShape.width.toInt(), modelShape.height.toInt())
Imgproc.resize(result, result, originalSize)
return result
}

override fun runModel(input: Mat): Mat {
val inputTensor = ImageProcessor.matToEValue(preprocess(input), module.getInputShape(0))
val outputTensor = forward(inputTensor)
return postprocess(outputTensor[0].toTensor())
val modelInput = preprocess(input)
val modelOutput = forward(modelInput)
return postprocess(modelOutput)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.swmansion.rnexecutorch.models.classification

import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Mat
import org.opencv.core.Size
import org.opencv.imgproc.Imgproc
import org.pytorch.executorch.Tensor
import org.pytorch.executorch.EValue
import com.swmansion.rnexecutorch.models.BaseModel


class ClassificationModel(reactApplicationContext: ReactApplicationContext) : BaseModel<Mat, Map<String, Float>>(reactApplicationContext) {
private fun getModelImageSize(): Size {
val inputShape = module.getInputShape(0)
val width = inputShape[inputShape.lastIndex]
val height = inputShape[inputShape.lastIndex - 1]

return Size(height.toDouble(), width.toDouble())
}

override fun preprocess(input: Mat): EValue {
Imgproc.resize(input, input, getModelImageSize())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI this is not 'in-place' wrt underlying memory

return ImageProcessor.matToEValue(input, module.getInputShape(0))
}

override fun postprocess(output: Array<EValue>): Map<String, Float> {
val tensor = output[0].toTensor()
val probabilities = softmax(tensor.dataAsFloatArray.toTypedArray())

val result = mutableMapOf<String, Float>()

for (i in probabilities.indices) {
result[imagenet1k_v1_labels[i]] = probabilities[i]
}

return result
}

override fun runModel(input: Mat): Map<String, Float> {
val modelInput = preprocess(input)
val modelOutput = forward(modelInput)
return postprocess(modelOutput)
}
}
Loading
Loading