Skip to content

Commit

Permalink
feat: [Android] replacing MLKit with TensorFlow lib
Browse files Browse the repository at this point in the history
  • Loading branch information
jaroslawkrol committed Feb 26, 2023
1 parent 9a39373 commit b8b1939
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 194 deletions.
3 changes: 1 addition & 2 deletions android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ dependencies {
api project(":react-native-vision-camera")
implementation "androidx.camera:camera-core:1.1.0-alpha06"
implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
implementation 'com.google.mlkit:object-detection-custom:17.0.0'

implementation 'org.tensorflow:tensorflow-lite-task-vision:0.4.0'
}

if (isNewArchitectureEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import com.visioncamerarealtimeobjectdetection.realtimeobjectdetectionprocessor.

/**
 * React Native package that registers the object-detection frame processor plugin.
 *
 * Registration happens as a side effect of [createNativeModules]; no native
 * modules or view managers are actually exposed, so both methods return empty
 * lists. (Diff residue removed: the page showed both the old no-arg plugin
 * constructor and the new context-taking one — only the latter is kept, since
 * the TFLite-based plugin needs the context to load its model asset.)
 */
class RealtimeObjectDetectionProcessorPluginPackage : ReactPackage {
    override fun createNativeModules(reactContext: ReactApplicationContext): List<NativeModule> {
        // The plugin needs the React context so the TensorFlow Lite detector
        // can resolve its model file from the app's assets.
        FrameProcessorPlugin.register(RealtimeObjectDetectionProcessorPlugin(reactContext))
        return emptyList()
    }

    override fun createViewManagers(reactContext: ReactApplicationContext): List<ViewManager<*, *>> {
        return emptyList()
    }
}
Original file line number Diff line number Diff line change
@@ -1,109 +1,114 @@
package com.visioncamerarealtimeobjectdetection.realtimeobjectdetectionprocessor

import android.graphics.Matrix
import android.graphics.RectF
import androidx.camera.core.ImageProxy
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReadableMap
import com.facebook.react.bridge.WritableNativeArray
import com.facebook.react.bridge.WritableNativeMap
import com.google.android.odml.image.MediaMlImageBuilder
import com.mrousavy.camera.frameprocessor.FrameProcessorPlugin
import org.tensorflow.lite.task.core.BaseOptions
import org.tensorflow.lite.task.vision.detector.ObjectDetector

/**
 * VisionCamera frame processor plugin ("detectObjects") that runs a TensorFlow
 * Lite object-detection model on each camera frame.
 *
 * The diff residue on this page interleaved the removed MLKit implementation
 * with the added TensorFlow Lite one; this is the coherent post-commit
 * (TensorFlow) version only.
 *
 * Returned value per frame: an array of objects, each with a "labels" array
 * (index / label / confidence) and a bounding box as "top"/"left"/"width"/
 * "height" normalized to [0, 1] relative to the rotated frame dimensions.
 */
class RealtimeObjectDetectionProcessorPlugin(reactContext: ReactApplicationContext) :
    FrameProcessorPlugin("detectObjects") {
    private val _context: ReactApplicationContext = reactContext

    // Lazily-created detector, cached so the model is loaded only once and
    // reused for every subsequent frame.
    private var _detector: ObjectDetector? = null

    /**
     * Rotates [rect] around its own center by [degrees] (clockwise) and
     * returns the axis-aligned bounding box of the rotated rectangle.
     * Used to map detector coordinates back through the frame's rotation.
     */
    fun rotateRect(rect: RectF, degrees: Int): RectF {
        val matrix = Matrix()
        matrix.postRotate(degrees.toFloat(), rect.centerX(), rect.centerY())
        val rotatedRect = RectF(rect)
        matrix.mapRect(rotatedRect)
        return rotatedRect
    }

    /**
     * Returns the shared [ObjectDetector], creating it on first call.
     *
     * Config keys read (only on the first call — later config changes are
     * ignored because the detector is cached):
     *  - "modelFile": model asset name, resolved under the "custom/" asset dir
     *  - "scoreThreshold": minimum detection score
     *  - "maxResults": maximum number of detections returned
     *  - "numThreads": TFLite interpreter thread count
     */
    fun getDetectorWithModelFile(config: ReadableMap): ObjectDetector {
        if (_detector == null) {
            val modelFile = config.getString("modelFile")

            val scoreThreshold = config.getDouble("scoreThreshold").toFloat()
            val maxResults = config.getInt("maxResults")
            val numThreads = config.getInt("numThreads")

            val baseOptionsBuilder = BaseOptions.builder().setNumThreads(numThreads)

            val optionsBuilder =
                ObjectDetector.ObjectDetectorOptions.builder()
                    .setBaseOptions(baseOptionsBuilder.build())
                    .setScoreThreshold(scoreThreshold)
                    .setMaxResults(maxResults)

            _detector =
                ObjectDetector.createFromFileAndOptions(
                    _context,
                    "custom/$modelFile",
                    optionsBuilder.build()
                )
        }
        // Safe: assigned above when null. NOTE(review): not thread-safe if
        // frames can arrive concurrently — confirm the frame processor's
        // threading model.
        return _detector!!
    }

    /**
     * Runs detection on [frame]. params[0] must be the config map described
     * in [getDetectorWithModelFile]. Returns an empty array when the frame
     * carries no image.
     */
    override fun callback(frame: ImageProxy, params: Array<Any>): WritableNativeArray {
        val mediaImage = frame.image

        if (mediaImage == null) {
            return WritableNativeArray()
        }

        val config = params[0] as ReadableMap

        val mlImage = MediaMlImageBuilder(mediaImage).build()

        // Swap width/height when the sensor image is rotated 90/270 so the
        // normalization below matches the upright frame.
        val frameWidth =
            if (frame.imageInfo.rotationDegrees == 90 || frame.imageInfo.rotationDegrees == 270)
                mediaImage.width
            else mediaImage.height
        val frameHeight =
            if (frame.imageInfo.rotationDegrees == 90 || frame.imageInfo.rotationDegrees == 270)
                mediaImage.height
            else mediaImage.width

        val results = WritableNativeArray()
        val detectedObjects = getDetectorWithModelFile(config).detect(mlImage)

        for (detectedObject in detectedObjects) {
            val labels = WritableNativeArray()

            for (label in detectedObject.categories) {
                val labelMap = WritableNativeMap()

                labelMap.putInt("index", label.index)
                labelMap.putString("label", label.label)
                labelMap.putDouble("confidence", label.score.toDouble())

                labels.pushMap(labelMap)
            }

            // Skip detections whose every category fell below the threshold.
            if (labels.size() > 0) {
                val objectMap = WritableNativeMap()

                objectMap.putArray("labels", labels)

                // Detector coordinates are in sensor orientation; rotate the
                // box into the upright frame before normalizing.
                val boundingBox =
                    rotateRect(detectedObject.boundingBox, frame.imageInfo.rotationDegrees)

                objectMap.putDouble("top", (boundingBox.top.toFloat() / frameHeight).toDouble())
                objectMap.putDouble("left", (boundingBox.left.toFloat() / frameWidth).toDouble())
                objectMap.putDouble(
                    "width",
                    ((boundingBox.right - boundingBox.left).toFloat() / frameWidth).toDouble()
                )
                objectMap.putDouble(
                    "height",
                    ((boundingBox.bottom - boundingBox.top).toFloat() / frameHeight).toDouble()
                )

                results.pushMap(objectMap)
            }
        }
        return results
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
3D06205CB3467253F64A37CC /* libPods-VisionCameraRealtimeObjectDetectionExample.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 114901553BFB5A046B3E2EAA /* libPods-VisionCameraRealtimeObjectDetectionExample.a */; };
81AB9BB82411601600AC10FF /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 81AB9BB72411601600AC10FF /* LaunchScreen.storyboard */; };
92040DB1035EEF7FED988466 /* libPods-VisionCameraRealtimeObjectDetectionExample-VisionCameraRealtimeObjectDetectionExampleTests.a in Frameworks */ = {isa = PBXBuildFile; fileRef = B10E3F3BE4CB759E89FD8E56 /* libPods-VisionCameraRealtimeObjectDetectionExample-VisionCameraRealtimeObjectDetectionExampleTests.a */; };
F3A2FA63A5604E6BA3CE5459 /* efficientdet-lite-bear.tflite in Resources */ = {isa = PBXBuildFile; fileRef = F7D3C5600E2749F79FDB4245 /* efficientdet-lite-bear.tflite */; };
/* End PBXBuildFile section */

/* Begin PBXContainerItemProxy section */
Expand Down Expand Up @@ -47,6 +48,7 @@
B10E3F3BE4CB759E89FD8E56 /* libPods-VisionCameraRealtimeObjectDetectionExample-VisionCameraRealtimeObjectDetectionExampleTests.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-VisionCameraRealtimeObjectDetectionExample-VisionCameraRealtimeObjectDetectionExampleTests.a"; sourceTree = BUILT_PRODUCTS_DIR; };
D0CCD4BBBDE995D6F19ABE06 /* Pods-VisionCameraRealtimeObjectDetectionExample.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-VisionCameraRealtimeObjectDetectionExample.release.xcconfig"; path = "Target Support Files/Pods-VisionCameraRealtimeObjectDetectionExample/Pods-VisionCameraRealtimeObjectDetectionExample.release.xcconfig"; sourceTree = "<group>"; };
ED297162215061F000B7C4FE /* JavaScriptCore.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = JavaScriptCore.framework; path = System/Library/Frameworks/JavaScriptCore.framework; sourceTree = SDKROOT; };
F7D3C5600E2749F79FDB4245 /* efficientdet-lite-bear.tflite */ = {isa = PBXFileReference; name = "efficientdet-lite-bear.tflite"; path = "../assets/model/efficientdet-lite-bear.tflite"; sourceTree = "<group>"; fileEncoding = undefined; lastKnownFileType = unknown; explicitFileType = undefined; includeInIndex = 0; };
/* End PBXFileReference section */

/* Begin PBXFrameworksBuildPhase section */
Expand Down Expand Up @@ -92,6 +94,7 @@
9ABB0012D599464DB3A16904 /* model_2.tflite */,
0C1B6D8B922444ADAC2EF124 /* lite-model_yolo-v5-tflite_tflite_model_1.tflite */,
58041736526C40F3839679C0 /* detect.tflite */,
F7D3C5600E2749F79FDB4245 /* efficientdet-lite-bear.tflite */,
);
name = Resources;
sourceTree = "<group>";
Expand Down Expand Up @@ -258,6 +261,7 @@
files = (
81AB9BB82411601600AC10FF /* LaunchScreen.storyboard in Resources */,
13B07FBF1A68108700A75B9A /* Images.xcassets in Resources */,
F3A2FA63A5604E6BA3CE5459 /* efficientdet-lite-bear.tflite in Resources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand Down
14 changes: 7 additions & 7 deletions example/src/components/ObjectDetector/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ const ObjectDetector: React.FC<Props> = ({ device }) => {
const [objects, setObjects] = useState<DetectedObject[]>([]);

const frameProcessorConfig: FrameProcessorConfig = {
modelFile: 'model.tflite',
classificationConfidenceThreshold: 0.4,
maxPerObjectLabelCount: 1,
modelFile: 'efficientdet-lite-bear.tflite',
scoreThreshold: 0.4,
maxResults: 1,
numThreads: 4,
};

const { width, height } = Dimensions.get('window');
Expand Down Expand Up @@ -50,9 +51,6 @@ const ObjectDetector: React.FC<Props> = ({ device }) => {
style={StyleSheet.absoluteFill}
device={device}
isActive={true}
// preset={'medium'}
// colorSpace="srgb"
video={true}
/>
{objects?.map(
(
Expand All @@ -64,7 +62,9 @@ const ObjectDetector: React.FC<Props> = ({ device }) => {
style={[styles.detectionFrame, { top, left, width, height }]}
>
<Text style={styles.detectionFrameLabel}>
{labels.map((label) => `${label.label} (${label.confidence})`).join(',')}
{labels
.map((label) => `${label.label} (${label.confidence})`)
.join(',')}
</Text>
</View>
)
Expand Down
Loading

0 comments on commit b8b1939

Please sign in to comment.