-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
## Description Image classification for iOS ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [x] iOS - [ ] Android ### Checklist - [x] I have performed a self-review of my code - [x] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [x] My changes generate no new warnings
- Loading branch information
Showing
21 changed files
with
1,412 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file added
BIN
+81.7 MB
examples/computer-vision/assets/classification/android/efficientnet_v2_s_xnnpack.pte
Binary file not shown.
Binary file added
BIN
+41.9 MB
examples/computer-vision/assets/classification/ios/efficientnet_v2_s_coreml_all.pte
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import { Platform } from 'react-native'; | ||
|
||
export const efficientnet_v2_s = | ||
Platform.OS === 'ios' | ||
? require('../assets/classification/ios/efficientnet_v2_s_coreml_all.pte') | ||
: require('../assets/classification/android/efficientnet_v2_s_xnnpack.pte'); |
126 changes: 126 additions & 0 deletions
126
examples/computer-vision/screens/ClassificationScreen.tsx
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import { useState } from 'react'; | ||
import Spinner from 'react-native-loading-spinner-overlay'; | ||
import { BottomBar } from '../components/BottomBar'; | ||
import { efficientnet_v2_s } from '../models/classification'; | ||
import { getImageUri } from '../utils'; | ||
import { useClassification } from 'react-native-executorch'; | ||
import { View, StyleSheet, Image, Text, ScrollView } from 'react-native'; | ||
|
||
export const ClassificationScreen = ({ | ||
imageUri, | ||
setImageUri, | ||
}: { | ||
imageUri: string; | ||
setImageUri: (imageUri: string) => void; | ||
}) => { | ||
const [results, setResults] = useState<{ label: string; score: number }[]>( | ||
[] | ||
); | ||
|
||
const model = useClassification({ | ||
modulePath: efficientnet_v2_s, | ||
}); | ||
|
||
const handleCameraPress = async (isCamera: boolean) => { | ||
const uri = await getImageUri(isCamera); | ||
if (typeof uri === 'string') { | ||
setImageUri(uri as string); | ||
setResults([]); | ||
} | ||
}; | ||
|
||
const runForward = async () => { | ||
if (imageUri) { | ||
try { | ||
const output = await model.forward(imageUri); | ||
const top10 = Object.entries(output) | ||
.sort(([, a], [, b]) => b - a) | ||
.slice(0, 10) | ||
.map(([label, score]) => ({ label, score })); | ||
setResults(top10); | ||
} catch (e) { | ||
console.error(e); | ||
} | ||
} | ||
}; | ||
|
||
if (!model.isModelReady) { | ||
return ( | ||
<Spinner | ||
visible={!model.isModelReady} | ||
textContent={`Loading the model...`} | ||
/> | ||
); | ||
} | ||
|
||
return ( | ||
<> | ||
<View style={styles.imageContainer}> | ||
<Image | ||
style={styles.image} | ||
resizeMode="contain" | ||
source={ | ||
imageUri | ||
? { uri: imageUri } | ||
: require('../assets/icons/executorch_logo.png') | ||
} | ||
/> | ||
{results.length > 0 && ( | ||
<View style={styles.results}> | ||
<Text style={styles.resultHeader}>Results Top 10</Text> | ||
<ScrollView style={styles.resultsList}> | ||
{results.map(({ label, score }) => ( | ||
<View key={label} style={styles.resultRecord}> | ||
<Text style={styles.resultLabel}>{label}</Text> | ||
<Text>{score.toFixed(3)}</Text> | ||
</View> | ||
))} | ||
</ScrollView> | ||
</View> | ||
)} | ||
</View> | ||
<BottomBar | ||
handleCameraPress={handleCameraPress} | ||
runForward={runForward} | ||
/> | ||
</> | ||
); | ||
}; | ||
|
||
const styles = StyleSheet.create({ | ||
imageContainer: { | ||
flex: 6, | ||
width: '100%', | ||
padding: 16, | ||
}, | ||
image: { | ||
flex: 2, | ||
borderRadius: 8, | ||
width: '100%', | ||
}, | ||
results: { | ||
flex: 1, | ||
alignItems: 'center', | ||
justifyContent: 'center', | ||
gap: 4, | ||
padding: 4, | ||
}, | ||
resultHeader: { | ||
fontSize: 18, | ||
color: 'navy', | ||
}, | ||
resultsList: { | ||
flex: 1, | ||
}, | ||
resultRecord: { | ||
flexDirection: 'row', | ||
width: '100%', | ||
justifyContent: 'space-between', | ||
padding: 8, | ||
borderBottomWidth: 1, | ||
}, | ||
resultLabel: { | ||
flex: 1, | ||
marginRight: 4, | ||
}, | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#import <RnExecutorchSpec/RnExecutorchSpec.h> | ||
|
||
@interface Classification : NSObject <NativeClassificationSpec> | ||
|
||
@end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#import "Classification.h" | ||
#import "utils/Fetcher.h" | ||
#import "models/BaseModel.h" | ||
#import "utils/ETError.h" | ||
#import "ImageProcessor.h" | ||
#import <ExecutorchLib/ETModel.h> | ||
#import <React/RCTBridgeModule.h> | ||
#import "models/classification/ClassificationModel.h" | ||
#import "opencv2/opencv.hpp" | ||
|
||
@implementation Classification { | ||
ClassificationModel* model; | ||
} | ||
|
||
RCT_EXPORT_MODULE() | ||
|
||
- (void)loadModule:(NSString *)modelSource | ||
resolve:(RCTPromiseResolveBlock)resolve | ||
reject:(RCTPromiseRejectBlock)reject { | ||
model = [[ClassificationModel alloc] init]; | ||
[model loadModel: [NSURL URLWithString:modelSource] completion:^(BOOL success, NSNumber *errorCode){ | ||
if(success){ | ||
resolve(errorCode); | ||
return; | ||
} | ||
|
||
reject(@"init_module_error", [NSString | ||
stringWithFormat:@"%ld", (long)[errorCode longValue]], nil); | ||
return; | ||
}]; | ||
} | ||
|
||
- (void)forward:(NSString *)input | ||
resolve:(RCTPromiseResolveBlock)resolve | ||
reject:(RCTPromiseRejectBlock)reject { | ||
@try { | ||
cv::Mat image = [ImageProcessor readImage:input]; | ||
NSDictionary *result = [model runModel:image]; | ||
|
||
resolve(result); | ||
return; | ||
} @catch (NSException *exception) { | ||
NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason); | ||
reject(@"forward_error", [NSString stringWithFormat:@"%@", exception.reason], | ||
nil); | ||
return; | ||
} | ||
} | ||
|
||
|
||
- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule: | ||
(const facebook::react::ObjCTurboModule::InitParams &)params { | ||
return std::make_shared<facebook::react::NativeClassificationSpecJSI>(params); | ||
} | ||
|
||
@end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
10 changes: 10 additions & 0 deletions
10
ios/RnExecutorch/models/classification/ClassificationModel.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#import "BaseModel.h" | ||
#import "opencv2/opencv.hpp" | ||
|
||
@interface ClassificationModel : BaseModel | ||
|
||
- (NSArray *)preprocess:(cv::Mat &)input; | ||
- (NSDictionary *)runModel:(cv::Mat &)input; | ||
- (NSDictionary *)postprocess:(NSArray *)output; | ||
|
||
@end |
Oops, something went wrong.