Skip to content

Commit

Permalink
@jakmro/classification ios (#54)
Browse files Browse the repository at this point in the history
## Description
Image classification for iOS 

### Type of change
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing
documentation)

### Tested on
- [x] iOS
- [ ] Android

### Checklist
- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [x] My changes generate no new warnings
  • Loading branch information
jakmro authored Dec 17, 2024
1 parent b3c07ad commit 7484a31
Show file tree
Hide file tree
Showing 21 changed files with 1,412 additions and 23 deletions.
19 changes: 11 additions & 8 deletions examples/computer-vision/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ import { useState } from 'react';
import { StyleTransferScreen } from './screens/StyleTransferScreen';
import { SafeAreaProvider, SafeAreaView } from 'react-native-safe-area-context';
import { View, StyleSheet } from 'react-native';
import { ClassificationScreen } from './screens/ClassificationScreen';

enum ModelType {
STYLE_TRANSFER,
OBJECT_DETECTION,
IMAGE_CLASSIFICATION,
CLASSIFICATION,
SEMANTIC_SEGMENTATION,
}

Expand All @@ -36,8 +37,10 @@ export default function App() {
);
case ModelType.OBJECT_DETECTION:
return <></>;
case ModelType.IMAGE_CLASSIFICATION:
return <></>;
case ModelType.CLASSIFICATION:
return (
<ClassificationScreen imageUri={imageUri} setImageUri={setImageUri} />
);
case ModelType.SEMANTIC_SEGMENTATION:
return <></>;
default:
Expand All @@ -57,17 +60,17 @@ export default function App() {
dataSource={[
'Style Transfer',
'Object Detection',
'Image Classification',
'Classification',
'Semantic Segmentation',
]}
onValueChange={(_, selectedIndex) => {
handleModeChange(selectedIndex);
}}
wrapperHeight={135}
wrapperHeight={100}
highlightColor={ColorPalette.primary}
wrapperBackground="#fff"
highlightBorderWidth={3}
itemHeight={60}
itemHeight={40}
activeItemTextStyle={styles.activeScrollItem}
/>
</View>
Expand All @@ -85,15 +88,15 @@ const styles = StyleSheet.create({
},
topContainer: {
marginTop: 5,
height: 150,
height: 145,
width: '100%',
alignItems: 'center',
justifyContent: 'center',
marginBottom: 16,
},
wheelPickerContainer: {
width: '100%',
height: 135,
height: 100,
},
activeScrollItem: {
color: ColorPalette.primary,
Expand Down
Binary file not shown.
Binary file not shown.
8 changes: 6 additions & 2 deletions examples/computer-vision/ios/Podfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ PODS:
- hermes-engine (0.76.3):
- hermes-engine/Pre-built (= 0.76.3)
- hermes-engine/Pre-built (0.76.3)
- opencv-rne (0.1.0)
- RCT-Folly (2024.01.01.00):
- boost
- DoubleConversion
Expand Down Expand Up @@ -1277,10 +1278,11 @@ PODS:
- ReactCommon/turbomodule/bridging
- ReactCommon/turbomodule/core
- Yoga
- react-native-executorch (0.1.524):
- react-native-executorch (0.1.100):
- DoubleConversion
- glog
- hermes-engine
- opencv-rne (~> 0.1.0)
- RCT-Folly (= 2024.01.01.00)
- RCTRequired
- RCTTypeSafety
Expand Down Expand Up @@ -1866,6 +1868,7 @@ DEPENDENCIES:

SPEC REPOS:
trunk:
- opencv-rne
- SocketRocket

EXTERNAL SOURCES:
Expand Down Expand Up @@ -2035,6 +2038,7 @@ SPEC CHECKSUMS:
fmt: 10c6e61f4be25dc963c36bd73fc7b1705fe975be
glog: 08b301085f15bcbb6ff8632a8ebaf239aae04e6a
hermes-engine: 0555a84ea495e8e3b4bde71b597cd87fbb382888
opencv-rne: 63e933ae2373fc91351f9a348dc46c3f523c2d3f
RCT-Folly: bf5c0376ffe4dd2cf438dcf86db385df9fdce648
RCTDeprecation: 2c5e1000b04ab70b53956aa498bf7442c3c6e497
RCTRequired: 5f785a001cf68a551c5f5040fb4c415672dbb481
Expand Down Expand Up @@ -2064,7 +2068,7 @@ SPEC CHECKSUMS:
React-logger: 26155dc23db5c9038794db915f80bd2044512c2e
React-Mapbuffer: ad1ba0205205a16dbff11b8ade6d1b3959451658
React-microtasksnativemodule: e771eb9eb6ace5884ee40a293a0e14a9d7a4343c
react-native-executorch: 9782e20c5bb4ddf7836af8779f887223228833e9
react-native-executorch: a30dd907f470d5c4f8135e2ba1fa2a3bb65ea47a
react-native-image-picker: bfb56e2b39dc63abfcc6de44ee239c6633f47d66
react-native-safe-area-context: d6406c2adbd41b2e09ab1c386781dc1c81a90919
React-nativeconfig: aeed6e2a8ac02b2df54476afcc7c663416c12bf7
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@
);
OTHER_SWIFT_FLAGS = "$(inherited) -D EXPO_CONFIGURATION_DEBUG";
PRODUCT_BUNDLE_IDENTIFIER = com.anonymous.computervision;
PRODUCT_NAME = "computervision";
PRODUCT_NAME = computervision;
SWIFT_OBJC_BRIDGING_HEADER = "computervision/computervision-Bridging-Header.h";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_VERSION = 5.0;
Expand Down Expand Up @@ -399,7 +399,7 @@
);
OTHER_SWIFT_FLAGS = "$(inherited) -D EXPO_CONFIGURATION_RELEASE";
PRODUCT_BUNDLE_IDENTIFIER = com.anonymous.computervision;
PRODUCT_NAME = "computervision";
PRODUCT_NAME = computervision;
SWIFT_OBJC_BRIDGING_HEADER = "computervision/computervision-Bridging-Header.h";
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
Expand Down
6 changes: 6 additions & 0 deletions examples/computer-vision/models/classification.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import { Platform } from 'react-native';

// EfficientNet V2 (small) classification model, resolved per platform:
// a CoreML-lowered .pte on iOS, an XNNPACK-lowered .pte on Android.
const runningOnIOS = Platform.OS === 'ios';

export const efficientnet_v2_s = runningOnIOS
  ? require('../assets/classification/ios/efficientnet_v2_s_coreml_all.pte')
  : require('../assets/classification/android/efficientnet_v2_s_xnnpack.pte');
126 changes: 126 additions & 0 deletions examples/computer-vision/screens/ClassificationScreen.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import { useState } from 'react';
import Spinner from 'react-native-loading-spinner-overlay';
import { BottomBar } from '../components/BottomBar';
import { efficientnet_v2_s } from '../models/classification';
import { getImageUri } from '../utils';
import { useClassification } from 'react-native-executorch';
import { View, StyleSheet, Image, Text, ScrollView } from 'react-native';

/**
 * Screen that classifies a user-selected image with the EfficientNet model
 * and shows the ten highest-scoring labels.
 *
 * @param imageUri    URI of the currently selected image ('' when none).
 * @param setImageUri Callback lifted to the parent so the selection survives
 *                    switching between model screens.
 */
export const ClassificationScreen = ({
  imageUri,
  setImageUri,
}: {
  imageUri: string;
  setImageUri: (imageUri: string) => void;
}) => {
  // Top-10 predictions for the current image, sorted by descending score.
  const [results, setResults] = useState<{ label: string; score: number }[]>(
    []
  );

  const model = useClassification({
    modulePath: efficientnet_v2_s,
  });

  // Picks a new image (camera when isCamera is true, photo library
  // otherwise) and clears results that belong to the previous image.
  const handleCameraPress = async (isCamera: boolean) => {
    const uri = await getImageUri(isCamera);
    if (typeof uri === 'string') {
      // `typeof` narrowing already proves `uri` is a string here,
      // so the original `uri as string` cast was redundant.
      setImageUri(uri);
      setResults([]);
    }
  };

  // Runs inference on the selected image and keeps the 10 best labels.
  // Errors are logged to the console, not surfaced in the UI.
  const runForward = async () => {
    if (!imageUri) {
      return; // Nothing selected yet; leave previous results untouched.
    }
    try {
      // forward() resolves to a label -> score mapping.
      const output = await model.forward(imageUri);
      const top10 = Object.entries(output)
        .sort(([, a], [, b]) => b - a)
        .slice(0, 10)
        .map(([label, score]) => ({ label, score }));
      setResults(top10);
    } catch (e) {
      console.error(e);
    }
  };

  // Block the whole screen until the model has finished loading.
  if (!model.isModelReady) {
    return (
      <Spinner
        visible={!model.isModelReady}
        textContent="Loading the model..."
      />
    );
  }

  return (
    <>
      <View style={styles.imageContainer}>
        <Image
          style={styles.image}
          resizeMode="contain"
          source={
            imageUri
              ? { uri: imageUri }
              : require('../assets/icons/executorch_logo.png')
          }
        />
        {results.length > 0 && (
          <View style={styles.results}>
            <Text style={styles.resultHeader}>Results Top 10</Text>
            <ScrollView style={styles.resultsList}>
              {results.map(({ label, score }) => (
                <View key={label} style={styles.resultRecord}>
                  <Text style={styles.resultLabel}>{label}</Text>
                  <Text>{score.toFixed(3)}</Text>
                </View>
              ))}
            </ScrollView>
          </View>
        )}
      </View>
      <BottomBar
        handleCameraPress={handleCameraPress}
        runForward={runForward}
      />
    </>
  );
};

// Layout for the classification screen: image preview on top, with a
// scrollable top-10 results list overlaid below it inside the same container.
const styles = StyleSheet.create({
// Takes most of the vertical space (the bottom bar sits below it).
imageContainer: {
flex: 6,
width: '100%',
padding: 16,
},
// Preview of the selected image (2/3 of the container when results show).
image: {
flex: 2,
borderRadius: 8,
width: '100%',
},
// Container for the "Results Top 10" header plus the scrollable list.
results: {
flex: 1,
alignItems: 'center',
justifyContent: 'center',
gap: 4,
padding: 4,
},
resultHeader: {
fontSize: 18,
color: 'navy',
},
resultsList: {
flex: 1,
},
// One label/score row; bottom border visually separates records.
resultRecord: {
flexDirection: 'row',
width: '100%',
justifyContent: 'space-between',
padding: 8,
borderBottomWidth: 1,
},
// Label grows to fill the row, pushing the score to the right edge.
resultLabel: {
flex: 1,
marginRight: 4,
},
});
10 changes: 5 additions & 5 deletions examples/computer-vision/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -3352,7 +3352,7 @@ __metadata:
metro-config: ^0.81.0
react: 18.3.1
react-native: 0.76.3
react-native-executorch: /Users/norbertklockiewicz/Desktop/work/react-native-executorch/react-native-executorch-0.1.524.tgz
react-native-executorch: ../../react-native-executorch-0.1.100.tgz
react-native-image-picker: ^7.2.2
react-native-loading-spinner-overlay: ^3.0.1
react-native-reanimated: ^3.16.3
Expand Down Expand Up @@ -6799,13 +6799,13 @@ __metadata:
languageName: node
linkType: hard

"react-native-executorch@file:/Users/norbertklockiewicz/Desktop/work/react-native-executorch/react-native-executorch-0.1.524.tgz::locator=computer-vision%40workspace%3A.":
version: 0.1.524
resolution: "react-native-executorch@file:/Users/norbertklockiewicz/Desktop/work/react-native-executorch/react-native-executorch-0.1.524.tgz::locator=computer-vision%40workspace%3A."
"react-native-executorch@file:../../react-native-executorch-0.1.100.tgz::locator=computer-vision%40workspace%3A.":
version: 0.1.100
resolution: "react-native-executorch@file:../../react-native-executorch-0.1.100.tgz::locator=computer-vision%40workspace%3A."
peerDependencies:
react: "*"
react-native: "*"
checksum: 4f67dbd81711997e5f890b2d7c7b025777d6bcf71a335c7efa56da3fa59a6d00a4915c8cb63e4362d931885adde4072fad63bfbabb93bc36edf116e4fa98f5b9
checksum: f258452e2050df59e150938f6482ef8eee5fbd4ef4fc4073a920293ca87d543daddf76c560701d0c2626e6677d964b446dad8e670e978ea4f80d0a1bd17dfa03
languageName: node
linkType: hard

Expand Down
5 changes: 5 additions & 0 deletions ios/RnExecutorch/Classification.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#import <RnExecutorchSpec/RnExecutorchSpec.h>

/// TurboModule bridge for on-device image classification.
/// Conforms to the codegen-generated NativeClassificationSpec protocol;
/// the implementation lives in Classification.mm.
@interface Classification : NSObject <NativeClassificationSpec>

@end
56 changes: 56 additions & 0 deletions ios/RnExecutorch/Classification.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#import "Classification.h"
#import "utils/Fetcher.h"
#import "models/BaseModel.h"
#import "utils/ETError.h"
#import "ImageProcessor.h"
#import <ExecutorchLib/ETModel.h>
#import <React/RCTBridgeModule.h>
#import "models/classification/ClassificationModel.h"
#import "opencv2/opencv.hpp"

// TurboModule exposing image classification to JavaScript.
// Thin bridge around ClassificationModel: loads an ExecuTorch model once,
// then runs inference on images passed in by path/URI.
@implementation Classification {
// Model instance; created in loadModule and reused for every forward call.
ClassificationModel* model;
}

RCT_EXPORT_MODULE()

// Loads the classification model from `modelSource`.
// Resolves the JS promise with the completion's errorCode on success;
// rejects with "init_module_error" and the numeric code otherwise.
// NOTE(review): modelSource is fed to +[NSURL URLWithString:], so it must
// already be a well-formed URL string (not a bare file path) — confirm the
// JS caller guarantees this.
- (void)loadModule:(NSString *)modelSource
resolve:(RCTPromiseResolveBlock)resolve
reject:(RCTPromiseRejectBlock)reject {
model = [[ClassificationModel alloc] init];
[model loadModel: [NSURL URLWithString:modelSource] completion:^(BOOL success, NSNumber *errorCode){
if(success){
resolve(errorCode);
return;
}

reject(@"init_module_error", [NSString
stringWithFormat:@"%ld", (long)[errorCode longValue]], nil);
return;
}];
}

// Runs inference on the image referenced by `input`.
// The image is decoded into a cv::Mat by ImageProcessor and handed to
// -[ClassificationModel runModel:]; the resulting NSDictionary resolves
// the JS promise. Any NSException raised during decoding or inference is
// caught and surfaced to JS as a "forward_error" rejection.
// NOTE(review): if loadModule was never called, `model` is nil and
// runModel: is a no-op returning nil — resolve(nil) in that case; confirm
// the JS side guards against calling forward before the model is ready.
- (void)forward:(NSString *)input
resolve:(RCTPromiseResolveBlock)resolve
reject:(RCTPromiseRejectBlock)reject {
@try {
cv::Mat image = [ImageProcessor readImage:input];
NSDictionary *result = [model runModel:image];

resolve(result);
return;
} @catch (NSException *exception) {
NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason);
reject(@"forward_error", [NSString stringWithFormat:@"%@", exception.reason],
nil);
return;
}
}


// Standard TurboModule hook: returns the codegen-generated JSI binding
// that routes JS calls to this module.
- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
(const facebook::react::ObjCTurboModule::InitParams &)params {
return std::make_shared<facebook::react::NativeClassificationSpecJSI>(params);
}

@end
3 changes: 1 addition & 2 deletions ios/RnExecutorch/models/StyleTransferModel.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#import <UIKit/UIKit.h>
#import "BaseModel.h"
#import <opencv2/opencv.hpp>
#import "opencv2/opencv.hpp"

@interface StyleTransferModel : BaseModel

Expand Down
10 changes: 10 additions & 0 deletions ios/RnExecutorch/models/classification/ClassificationModel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#import "BaseModel.h"
#import "opencv2/opencv.hpp"

/// Image classification model built on BaseModel.
/// Pipeline: preprocess -> inference -> postprocess on an OpenCV image.
@interface ClassificationModel : BaseModel

/// Converts the input image into the model's expected input representation.
/// NOTE(review): the exact resize/normalization is implemented in the .mm
/// (not visible here) — consult it before relying on specifics.
- (NSArray *)preprocess:(cv::Mat &)input;

/// Runs the full pipeline on `input` and returns the results dictionary.
- (NSDictionary *)runModel:(cv::Mat &)input;

/// Maps raw model output to the final results dictionary.
- (NSDictionary *)postprocess:(NSArray *)output;

@end
Loading

0 comments on commit 7484a31

Please sign in to comment.