diff --git a/android/src/main/java/com/swmansion/rnexecutorch/Classification.kt b/android/src/main/java/com/swmansion/rnexecutorch/Classification.kt index eb573c7..3967a08 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/Classification.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/Classification.kt @@ -32,7 +32,7 @@ class Classification(reactContext: ReactApplicationContext) : classificationModel.loadModel(modelSource) promise.resolve(0) } catch (e: Exception) { - promise.reject(e.message!!, ETError.InvalidModelPath.toString()) + promise.reject(e.message!!, ETError.InvalidModelSource.toString()) } } diff --git a/android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt b/android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt index 30150b1..b3a57f2 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt @@ -17,13 +17,13 @@ class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(react return NAME } - override fun loadModule(modelPath: String, promise: Promise) { + override fun loadModule(modelSource: String, promise: Promise) { Fetcher.downloadModel( reactApplicationContext, - modelPath, + modelSource, ) { path, error -> if (error != null) { - promise.reject(error.message!!, ETError.InvalidModelPath.toString()) + promise.reject(error.message!!, ETError.InvalidModelSource.toString()) return@downloadModel } diff --git a/android/src/main/java/com/swmansion/rnexecutorch/StyleTransfer.kt b/android/src/main/java/com/swmansion/rnexecutorch/StyleTransfer.kt index ea4ecdf..ed4e4a2 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/StyleTransfer.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/StyleTransfer.kt @@ -31,7 +31,7 @@ class StyleTransfer(reactContext: ReactApplicationContext) : styleTransferModel.loadModel(modelSource) promise.resolve(0) } catch (e: Exception) { - promise.reject(e.message!!, ETError.InvalidModelPath.toString()) + promise.reject(e.message!!, ETError.InvalidModelSource.toString()) } } diff --git a/android/src/main/java/com/swmansion/rnexecutorch/utils/ETError.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/ETError.kt index 6cf2d03..ce335f1 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/utils/ETError.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/utils/ETError.kt @@ -4,7 +4,7 @@ enum class ETError(val code: Int) { UndefinedError(0x65), ModuleNotLoaded(0x66), FileWriteFailed(0x67), - InvalidModelPath(0xff), + InvalidModelSource(0xff), // System errors Ok(0x00), diff --git a/docs/docs/guides/running-llms.md b/docs/docs/guides/running-llms.md index eb3f0e8..a2df135 100644 --- a/docs/docs/guides/running-llms.md +++ b/docs/docs/guides/running-llms.md @@ -14,10 +14,10 @@ React Native ExecuTorch supports Llama 3.2 models, including quantized versions. In order to load a model into the app, you need to run the following code: ```typescript -import { useLLM, LLAMA3_2_1B_URL } from 'react-native-executorch'; +import { useLLM, LLAMA3_2_1B } from 'react-native-executorch'; const llama = useLLM({ - modelSource: LLAMA3_2_1B_URL, + modelSource: LLAMA3_2_1B, tokenizer: require('../assets/tokenizer.bin'), contextWindowLength: 3, }); @@ -91,7 +91,7 @@ In order to send a message to the model, one can use the following code: ```typescript const llama = useLLM( - modelSource: LLAMA3_2_1B_URL, + modelSource: LLAMA3_2_1B, tokenizer: require('../assets/tokenizer.bin'), ); diff --git a/examples/computer-vision/screens/ClassificationScreen.tsx b/examples/computer-vision/screens/ClassificationScreen.tsx index 587d072..3434c0d 100644 --- a/examples/computer-vision/screens/ClassificationScreen.tsx +++ b/examples/computer-vision/screens/ClassificationScreen.tsx @@ -17,7 +17,7 @@ export const ClassificationScreen = ({ ); const model = useClassification({ - modulePath: EFFICIENTNET_V2_S, + modelSource: EFFICIENTNET_V2_S, }); const handleCameraPress = async (isCamera: boolean) => { diff --git a/examples/computer-vision/screens/ObjectDetectionScreen.tsx b/examples/computer-vision/screens/ObjectDetectionScreen.tsx index 67f32e7..2e6acaf 100644 --- a/examples/computer-vision/screens/ObjectDetectionScreen.tsx +++ b/examples/computer-vision/screens/ObjectDetectionScreen.tsx @@ -5,7 +5,7 @@ import { getImage } from '../utils'; import { Detection, useObjectDetection, - SSDLITE_320_MOBILENET_V3_LARGE_URL, + SSDLITE_320_MOBILENET_V3_LARGE, } from 'react-native-executorch'; import { View, StyleSheet, Image } from 'react-native'; import ImageWithBboxes from '../components/ImageWithBboxes'; @@ -24,7 +24,7 @@ export const ObjectDetectionScreen = ({ }>(); const ssdLite = useObjectDetection({ - modelSource: SSDLITE_320_MOBILENET_V3_LARGE_URL, + modelSource: SSDLITE_320_MOBILENET_V3_LARGE, }); const handleCameraPress = async (isCamera: boolean) => { diff --git a/examples/computer-vision/screens/StyleTransferScreen.tsx b/examples/computer-vision/screens/StyleTransferScreen.tsx index a82657b..633e329 100644 --- a/examples/computer-vision/screens/StyleTransferScreen.tsx +++ b/examples/computer-vision/screens/StyleTransferScreen.tsx @@ -15,7 +15,7 @@ export const StyleTransferScreen = ({ setImageUri: (imageUri: string) => void; }) => { const model = useStyleTransfer({ - modulePath: STYLE_TRANSFER_CANDY, + modelSource: STYLE_TRANSFER_CANDY, }); const handleCameraPress = async (isCamera: boolean) => { diff --git a/examples/llama/screens/ChatScreen.tsx b/examples/llama/screens/ChatScreen.tsx index ea656d2..10f18ab 100644 --- a/examples/llama/screens/ChatScreen.tsx +++ b/examples/llama/screens/ChatScreen.tsx @@ -14,7 +14,7 @@ import { SafeAreaView } from 'react-native-safe-area-context'; import SWMIcon from '../assets/icons/swm_icon.svg'; import SendIcon from '../assets/icons/send_icon.svg'; import Spinner from 'react-native-loading-spinner-overlay'; -import { LLAMA3_2_1B_QLORA_URL, useLLM } from 'react-native-executorch'; +import { LLAMA3_2_1B_QLORA, useLLM } from 'react-native-executorch'; import PauseIcon from '../assets/icons/pause_icon.svg'; import ColorPalette from '../colors'; import Messages from '../components/Messages'; @@ -25,7 +25,7 @@ export default function ChatScreen() { const [isTextInputFocused, setIsTextInputFocused] = useState(false); const [userInput, setUserInput] = useState(''); const llama = useLLM({ - modelSource: LLAMA3_2_1B_QLORA_URL, + modelSource: LLAMA3_2_1B_QLORA, tokenizerSource: require('../assets/tokenizer.bin'), contextWindowLength: 6, }); diff --git a/ios/RnExecutorch/models/BaseModel.mm b/ios/RnExecutorch/models/BaseModel.mm index 76ee31f..b06a353 100644 --- a/ios/RnExecutorch/models/BaseModel.mm +++ b/ios/RnExecutorch/models/BaseModel.mm @@ -13,7 +13,7 @@ - (void)loadModel:(NSURL *)modelURL completion:(void (^)(BOOL success, NSNumber* module = [[ETModel alloc] init]; [Fetcher fetchResource:modelURL resourceType:ResourceType::MODEL completionHandler:^(NSString *filePath, NSError *error) { if (error) { - completion(NO, @(InvalidModelPath)); + completion(NO, @(InvalidModelSource)); return; } NSNumber *result = [self->module loadModel: filePath]; diff --git a/ios/RnExecutorch/utils/ETError.h b/ios/RnExecutorch/utils/ETError.h index 092e79d..f1394a0 100644 --- a/ios/RnExecutorch/utils/ETError.h +++ b/ios/RnExecutorch/utils/ETError.h @@ -2,7 +2,7 @@ typedef NS_ENUM(NSUInteger, ETError) { UndefinedError = 0x65, ModuleNotLoaded = 0x66, FileWriteFailed = 0x67, - InvalidModelPath = 0xff, + InvalidModelSource = 0xff, Ok = 0x00, Internal = 0x01, diff --git a/src/ETModule.ts b/src/ETModule.ts index 0d99ef5..be32118 100644 --- a/src/ETModule.ts +++ b/src/ETModule.ts @@ -15,11 +15,11 @@ const getTypeIdentifier = (arr: ETInput): number => { }; interface Props { - modulePath: string | number; + modelSource: string | number; } export const useExecutorchModule = ({ - modulePath, + modelSource, }: Props): ExecutorchModule => { const [error, setError] = useState(null); const [isModelLoading, setIsModelLoading] = useState(true); @@ -27,9 +27,9 @@ export const useExecutorchModule = ({ useEffect(() => { const loadModel = async () => { - let path = modulePath; - if (typeof modulePath === 'number') { - path = Image.resolveAssetSource(modulePath).uri; + let path = modelSource; + if (typeof modelSource === 'number') { + path = Image.resolveAssetSource(modelSource).uri; } try { @@ -42,7 +42,7 @@ export const useExecutorchModule = ({ } }; loadModel(); - }, [modulePath]); + }, [modelSource]); const forward = async (input: ETInput, shape: number[]) => { if (isModelLoading) { diff --git a/src/Error.ts b/src/Error.ts index 24e6c32..7678563 100644 --- a/src/Error.ts +++ b/src/Error.ts @@ -4,7 +4,7 @@ export enum ETError { ModuleNotLoaded = 0x66, FileWriteFailed = 0x67, ModelGenerating = 0x68, - InvalidModelPath = 0xff, + InvalidModelSource = 0xff, // ExecuTorch mapped errors // Based on: https://github.com/pytorch/executorch/blob/main/runtime/core/error.h diff --git a/src/StyleTransfer.ts b/src/StyleTransfer.ts index 7496aab..16fe521 100644 --- a/src/StyleTransfer.ts +++ b/src/StyleTransfer.ts @@ -4,7 +4,7 @@ import { StyleTransfer } from './native/RnExecutorchModules'; import { ETError, getError } from './Error'; interface Props { - modulePath: string | number; + modelSource: string | number; } interface StyleTransferModule { @@ -15,7 +15,7 @@ interface StyleTransferModule { } export const useStyleTransfer = ({ - modulePath, + modelSource, }: Props): StyleTransferModule => { const [error, setError] = useState(null); const [isModelReady, setIsModelReady] = useState(false); @@ -23,10 +23,10 @@ export const useStyleTransfer = ({ useEffect(() => { const loadModel = async () => { - let path = modulePath; + let path = modelSource; - if (typeof modulePath === 'number') { - path = Image.resolveAssetSource(modulePath).uri; + if (typeof modelSource === 'number') { + path = Image.resolveAssetSource(modelSource).uri; } try { @@ -39,7 +39,7 @@ export const useStyleTransfer = ({ }; loadModel(); - }, [modulePath]); + }, [modelSource]); const forward = async (input: string) => { if (!isModelReady) { diff --git a/src/constants/modelUrls.ts b/src/constants/modelUrls.ts index dca54e4..2f57331 100644 --- a/src/constants/modelUrls.ts +++ b/src/constants/modelUrls.ts @@ -1,17 +1,17 @@ import { Platform } from 'react-native'; // LLM's -export const LLAMA3_2_3B_URL = +export const LLAMA3_2_3B = 'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-3B/original/llama3_2_3B_bf16.pte'; -export const LLAMA3_2_3B_QLORA_URL = +export const LLAMA3_2_3B_QLORA = 'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-3B/QLoRA/llama3_2-3B_qat_lora.pte'; -export const LLAMA3_2_3B_SPINQUANT_URL = +export const LLAMA3_2_3B_SPINQUANT = 'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-3B/spinquant/llama3_2_3B_spinquant.pte'; -export const LLAMA3_2_1B_URL = +export const LLAMA3_2_1B = 'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-1B/original/llama3_2_bf16.pte'; -export const LLAMA3_2_1B_QLORA_URL = +export const LLAMA3_2_1B_QLORA = 'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-1B/QLoRA/llama3_2_qat_lora.pte'; -export const LLAMA3_2_1B_SPINQUANT_URL = +export const LLAMA3_2_1B_SPINQUANT = 'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-1B/spinquant/llama3_2_spinquant.pte'; export const LLAMA3_2_1B_TOKENIZER = 'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-1B/original/tokenizer.bin'; @@ -24,26 +24,32 @@ export const EFFICIENTNET_V2_S = ? 'https://huggingface.co/software-mansion/react-native-executorch-efficientnet-v2-s/resolve/v0.2.0/coreml/efficientnet_v2_s_coreml_all.pte' : 'https://huggingface.co/software-mansion/react-native-executorch-efficientnet-v2-s/resolve/v0.2.0/xnnpack/efficientnet_v2_s_xnnpack.pte'; +// Object detection +export const SSDLITE_320_MOBILENET_V3_LARGE = + 'https://huggingface.co/software-mansion/react-native-executorch-ssdlite320-mobilenet-v3-large/resolve/v0.2.0/ssdlite320-mobilenetv3-large.pte'; + // Style transfer export const STYLE_TRANSFER_CANDY = Platform.OS === 'ios' ? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-candy/resolve/v0.2.0/coreml/style_transfer_candy_coreml.pte' : 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-candy/resolve/v0.2.0/xnnpack/style_transfer_candy_xnnpack.pte'; - export const STYLE_TRANSFER_MOSAIC = Platform.OS === 'ios' - ? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-mosaic/resolve/main/coreml/style_transfer_mosaic_coreml.pte' - : 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-mosaic/resolve/main/xnnpack/style_transfer_mosaic_xnnpack.pte'; - + ? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-mosaic/resolve/v0.2.0/coreml/style_transfer_mosaic_coreml.pte' + : 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-mosaic/resolve/v0.2.0/xnnpack/style_transfer_mosaic_xnnpack.pte'; export const STYLE_TRANSFER_RAIN_PRINCESS = Platform.OS === 'ios' - ? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-rain-princess/resolve/main/coreml/style_transfer_rain_princess_coreml.pte' - : 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-rain-princess/resolve/main/xnnpack/style_transfer_rain_princess_xnnpack.pte'; + ? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-rain-princess/resolve/v0.2.0/coreml/style_transfer_rain_princess_coreml.pte' + : 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-rain-princess/resolve/v0.2.0/xnnpack/style_transfer_rain_princess_xnnpack.pte'; export const STYLE_TRANSFER_UDNIE = Platform.OS === 'ios' - ? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/main/coreml/style_transfer_udnie_coreml.pte' - : 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/main/xnnpack/style_transfer_udnie_xnnpack.pte'; + ? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/v0.2.0/coreml/style_transfer_udnie_coreml.pte' + : 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/v0.2.0/xnnpack/style_transfer_udnie_xnnpack.pte'; -// Object detection -export const SSDLITE_320_MOBILENET_V3_LARGE_URL = - 'https://huggingface.co/software-mansion/react-native-executorch-ssdlite320-mobilenet-v3-large/resolve/main/ssdlite320-mobilenetv3-large.pte'; +// Backward compatibility +export const LLAMA3_2_3B_URL = LLAMA3_2_3B; +export const LLAMA3_2_3B_QLORA_URL = LLAMA3_2_3B_QLORA; +export const LLAMA3_2_3B_SPINQUANT_URL = LLAMA3_2_3B_SPINQUANT; +export const LLAMA3_2_1B_URL = LLAMA3_2_1B; +export const LLAMA3_2_1B_QLORA_URL = LLAMA3_2_1B_QLORA; +export const LLAMA3_2_1B_SPINQUANT_URL = LLAMA3_2_1B_SPINQUANT; diff --git a/src/models/Classification.ts b/src/models/Classification.ts index bb2fd1e..4ece3f6 100644 --- a/src/models/Classification.ts +++ b/src/models/Classification.ts @@ -4,7 +4,7 @@ import { Classification } from '../native/RnExecutorchModules'; import { ETError, getError } from '../Error'; interface Props { - modulePath: string | number; + modelSource: string | number; } interface ClassificationModule { @@ -15,7 +15,7 @@ interface ClassificationModule { } export const useClassification = ({ - modulePath, + modelSource, }: Props): ClassificationModule => { const [error, setError] = useState(null); const [isModelReady, setIsModelReady] = useState(false); @@ -23,10 +23,10 @@ export const useClassification = ({ useEffect(() => { const loadModel = async () => { - let path = modulePath; + let path = modelSource; - if (typeof modulePath === 'number') { - path = Image.resolveAssetSource(modulePath).uri; + if (typeof modelSource === 'number') { + path = Image.resolveAssetSource(modelSource).uri; } try { @@ -40,7 +40,7 @@ export const useClassification = ({ }; loadModel(); - }, [modulePath]); + }, [modelSource]); const forward = async (input: string) => { if (!isModelReady) {