Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

@jakmro/classification ios #54

Merged
merged 4 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions examples/computer-vision/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ import { useState } from 'react';
import { StyleTransferScreen } from './screens/StyleTransferScreen';
import { SafeAreaProvider, SafeAreaView } from 'react-native-safe-area-context';
import { View, StyleSheet } from 'react-native';
import { ClassificationScreen } from './screens/ClassificationScreen';

enum ModelType {
STYLE_TRANSFER,
OBJECT_DETECTION,
IMAGE_CLASSIFICATION,
CLASSIFICATION,
SEMANTIC_SEGMENTATION,
}

Expand All @@ -36,8 +37,10 @@ export default function App() {
);
case ModelType.OBJECT_DETECTION:
return <></>;
case ModelType.IMAGE_CLASSIFICATION:
return <></>;
case ModelType.CLASSIFICATION:
return (
<ClassificationScreen imageUri={imageUri} setImageUri={setImageUri} />
);
case ModelType.SEMANTIC_SEGMENTATION:
return <></>;
default:
Expand All @@ -57,17 +60,17 @@ export default function App() {
dataSource={[
'Style Transfer',
'Object Detection',
'Image Classification',
'Classification',
'Semantic Segmentation',
]}
onValueChange={(_, selectedIndex) => {
handleModeChange(selectedIndex);
}}
wrapperHeight={135}
wrapperHeight={100}
highlightColor={ColorPalette.primary}
wrapperBackground="#fff"
highlightBorderWidth={3}
itemHeight={60}
itemHeight={40}
activeItemTextStyle={styles.activeScrollItem}
/>
</View>
Expand All @@ -85,15 +88,15 @@ const styles = StyleSheet.create({
},
topContainer: {
marginTop: 5,
height: 150,
height: 145,
width: '100%',
alignItems: 'center',
justifyContent: 'center',
marginBottom: 16,
},
wheelPickerContainer: {
width: '100%',
height: 135,
height: 100,
},
activeScrollItem: {
color: ColorPalette.primary,
Expand Down
Binary file not shown.
Binary file not shown.
8 changes: 6 additions & 2 deletions examples/computer-vision/ios/Podfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ PODS:
- hermes-engine (0.76.3):
- hermes-engine/Pre-built (= 0.76.3)
- hermes-engine/Pre-built (0.76.3)
- opencv-rne (0.1.0)
- RCT-Folly (2024.01.01.00):
- boost
- DoubleConversion
Expand Down Expand Up @@ -1277,10 +1278,11 @@ PODS:
- ReactCommon/turbomodule/bridging
- ReactCommon/turbomodule/core
- Yoga
- react-native-executorch (0.1.524):
- react-native-executorch (0.1.100):
- DoubleConversion
- glog
- hermes-engine
- opencv-rne (~> 0.1.0)
- RCT-Folly (= 2024.01.01.00)
- RCTRequired
- RCTTypeSafety
Expand Down Expand Up @@ -1866,6 +1868,7 @@ DEPENDENCIES:

SPEC REPOS:
trunk:
- opencv-rne
- SocketRocket

EXTERNAL SOURCES:
Expand Down Expand Up @@ -2035,6 +2038,7 @@ SPEC CHECKSUMS:
fmt: 10c6e61f4be25dc963c36bd73fc7b1705fe975be
glog: 08b301085f15bcbb6ff8632a8ebaf239aae04e6a
hermes-engine: 0555a84ea495e8e3b4bde71b597cd87fbb382888
opencv-rne: 63e933ae2373fc91351f9a348dc46c3f523c2d3f
RCT-Folly: bf5c0376ffe4dd2cf438dcf86db385df9fdce648
RCTDeprecation: 2c5e1000b04ab70b53956aa498bf7442c3c6e497
RCTRequired: 5f785a001cf68a551c5f5040fb4c415672dbb481
Expand Down Expand Up @@ -2064,7 +2068,7 @@ SPEC CHECKSUMS:
React-logger: 26155dc23db5c9038794db915f80bd2044512c2e
React-Mapbuffer: ad1ba0205205a16dbff11b8ade6d1b3959451658
React-microtasksnativemodule: e771eb9eb6ace5884ee40a293a0e14a9d7a4343c
react-native-executorch: 9782e20c5bb4ddf7836af8779f887223228833e9
react-native-executorch: a30dd907f470d5c4f8135e2ba1fa2a3bb65ea47a
react-native-image-picker: bfb56e2b39dc63abfcc6de44ee239c6633f47d66
react-native-safe-area-context: d6406c2adbd41b2e09ab1c386781dc1c81a90919
React-nativeconfig: aeed6e2a8ac02b2df54476afcc7c663416c12bf7
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@
);
OTHER_SWIFT_FLAGS = "$(inherited) -D EXPO_CONFIGURATION_DEBUG";
PRODUCT_BUNDLE_IDENTIFIER = com.anonymous.computervision;
PRODUCT_NAME = "computervision";
PRODUCT_NAME = computervision;
SWIFT_OBJC_BRIDGING_HEADER = "computervision/computervision-Bridging-Header.h";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_VERSION = 5.0;
Expand Down Expand Up @@ -399,7 +399,7 @@
);
OTHER_SWIFT_FLAGS = "$(inherited) -D EXPO_CONFIGURATION_RELEASE";
PRODUCT_BUNDLE_IDENTIFIER = com.anonymous.computervision;
PRODUCT_NAME = "computervision";
PRODUCT_NAME = computervision;
SWIFT_OBJC_BRIDGING_HEADER = "computervision/computervision-Bridging-Header.h";
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
Expand Down
6 changes: 6 additions & 0 deletions examples/computer-vision/models/classification.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import { Platform } from 'react-native';

export const efficientnet_v2_s =
Platform.OS === 'ios'
? require('../assets/classification/ios/efficientnet_v2_s_coreml_all.pte')
: require('../assets/classification/android/efficientnet_v2_s_xnnpack.pte');
126 changes: 126 additions & 0 deletions examples/computer-vision/screens/ClassificationScreen.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import { useState } from 'react';
import Spinner from 'react-native-loading-spinner-overlay';
import { BottomBar } from '../components/BottomBar';
import { efficientnet_v2_s } from '../models/classification';
import { getImageUri } from '../utils';
import { useClassification } from 'react-native-executorch';
import { View, StyleSheet, Image, Text, ScrollView } from 'react-native';

export const ClassificationScreen = ({
imageUri,
setImageUri,
}: {
imageUri: string;
setImageUri: (imageUri: string) => void;
}) => {
const [results, setResults] = useState<{ label: string; score: number }[]>(
[]
);

const model = useClassification({
modulePath: efficientnet_v2_s,
});

const handleCameraPress = async (isCamera: boolean) => {
const uri = await getImageUri(isCamera);
if (typeof uri === 'string') {
setImageUri(uri as string);
setResults([]);
}
};

const runForward = async () => {
if (imageUri) {
try {
const output = await model.forward(imageUri);
const top10 = Object.entries(output)
.sort(([, a], [, b]) => b - a)
.slice(0, 10)
.map(([label, score]) => ({ label, score }));
setResults(top10);
} catch (e) {
console.error(e);
}
}
};

if (!model.isModelReady) {
return (
<Spinner
visible={!model.isModelReady}
textContent={`Loading the model...`}
/>
);
}

return (
<>
<View style={styles.imageContainer}>
<Image
style={styles.image}
resizeMode="contain"
source={
imageUri
? { uri: imageUri }
: require('../assets/icons/executorch_logo.png')
}
/>
{results.length > 0 && (
<View style={styles.results}>
<Text style={styles.resultHeader}>Results Top 10</Text>
<ScrollView style={styles.resultsList}>
{results.map(({ label, score }) => (
<View key={label} style={styles.resultRecord}>
<Text style={styles.resultLabel}>{label}</Text>
<Text>{score.toFixed(3)}</Text>
</View>
))}
</ScrollView>
</View>
)}
</View>
<BottomBar
handleCameraPress={handleCameraPress}
runForward={runForward}
/>
</>
);
};

const styles = StyleSheet.create({
imageContainer: {
flex: 6,
width: '100%',
padding: 16,
},
image: {
flex: 2,
borderRadius: 8,
width: '100%',
},
results: {
flex: 1,
alignItems: 'center',
justifyContent: 'center',
gap: 4,
padding: 4,
},
resultHeader: {
fontSize: 18,
color: 'navy',
},
resultsList: {
flex: 1,
},
resultRecord: {
flexDirection: 'row',
width: '100%',
justifyContent: 'space-between',
padding: 8,
borderBottomWidth: 1,
},
resultLabel: {
flex: 1,
marginRight: 4,
},
});
10 changes: 5 additions & 5 deletions examples/computer-vision/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -3352,7 +3352,7 @@ __metadata:
metro-config: ^0.81.0
react: 18.3.1
react-native: 0.76.3
react-native-executorch: /Users/norbertklockiewicz/Desktop/work/react-native-executorch/react-native-executorch-0.1.524.tgz
react-native-executorch: ../../react-native-executorch-0.1.100.tgz
react-native-image-picker: ^7.2.2
react-native-loading-spinner-overlay: ^3.0.1
react-native-reanimated: ^3.16.3
Expand Down Expand Up @@ -6799,13 +6799,13 @@ __metadata:
languageName: node
linkType: hard

"react-native-executorch@file:/Users/norbertklockiewicz/Desktop/work/react-native-executorch/react-native-executorch-0.1.524.tgz::locator=computer-vision%40workspace%3A.":
version: 0.1.524
resolution: "react-native-executorch@file:/Users/norbertklockiewicz/Desktop/work/react-native-executorch/react-native-executorch-0.1.524.tgz::locator=computer-vision%40workspace%3A."
"react-native-executorch@file:../../react-native-executorch-0.1.100.tgz::locator=computer-vision%40workspace%3A.":
version: 0.1.100
resolution: "react-native-executorch@file:../../react-native-executorch-0.1.100.tgz::locator=computer-vision%40workspace%3A."
peerDependencies:
react: "*"
react-native: "*"
checksum: 4f67dbd81711997e5f890b2d7c7b025777d6bcf71a335c7efa56da3fa59a6d00a4915c8cb63e4362d931885adde4072fad63bfbabb93bc36edf116e4fa98f5b9
checksum: f258452e2050df59e150938f6482ef8eee5fbd4ef4fc4073a920293ca87d543daddf76c560701d0c2626e6677d964b446dad8e670e978ea4f80d0a1bd17dfa03
languageName: node
linkType: hard

Expand Down
5 changes: 5 additions & 0 deletions ios/RnExecutorch/Classification.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#import <RnExecutorchSpec/RnExecutorchSpec.h>

@interface Classification : NSObject <NativeClassificationSpec>

@end
56 changes: 56 additions & 0 deletions ios/RnExecutorch/Classification.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#import "Classification.h"
#import "utils/Fetcher.h"
#import "models/BaseModel.h"
#import "utils/ETError.h"
#import "ImageProcessor.h"
#import <ExecutorchLib/ETModel.h>
#import <React/RCTBridgeModule.h>
#import "models/classification/ClassificationModel.h"
#import "opencv2/opencv.hpp"

@implementation Classification {
ClassificationModel* model;
}

RCT_EXPORT_MODULE()

- (void)loadModule:(NSString *)modelSource
resolve:(RCTPromiseResolveBlock)resolve
reject:(RCTPromiseRejectBlock)reject {
model = [[ClassificationModel alloc] init];
[model loadModel: [NSURL URLWithString:modelSource] completion:^(BOOL success, NSNumber *errorCode){
if(success){
resolve(errorCode);
return;
}

reject(@"init_module_error", [NSString
stringWithFormat:@"%ld", (long)[errorCode longValue]], nil);
return;
}];
}

- (void)forward:(NSString *)input
resolve:(RCTPromiseResolveBlock)resolve
reject:(RCTPromiseRejectBlock)reject {
@try {
cv::Mat image = [ImageProcessor readImage:input];
NSDictionary *result = [model runModel:image];

resolve(result);
return;
} @catch (NSException *exception) {
NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason);
reject(@"forward_error", [NSString stringWithFormat:@"%@", exception.reason],
nil);
return;
}
}


- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
(const facebook::react::ObjCTurboModule::InitParams &)params {
return std::make_shared<facebook::react::NativeClassificationSpecJSI>(params);
}

@end
3 changes: 1 addition & 2 deletions ios/RnExecutorch/models/StyleTransferModel.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#import <UIKit/UIKit.h>
#import "BaseModel.h"
#import <opencv2/opencv.hpp>
#import "opencv2/opencv.hpp"

@interface StyleTransferModel : BaseModel

Expand Down
10 changes: 10 additions & 0 deletions ios/RnExecutorch/models/classification/ClassificationModel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#import "BaseModel.h"
#import "opencv2/opencv.hpp"

@interface ClassificationModel : BaseModel

- (NSArray *)preprocess:(cv::Mat &)input;
- (NSDictionary *)runModel:(cv::Mat &)input;
- (NSDictionary *)postprocess:(NSArray *)output;

@end
Loading
Loading