diff --git a/android/src/main/java/com/swmansion/rnexecutorch/ObjectDetection.kt b/android/src/main/java/com/swmansion/rnexecutorch/ObjectDetection.kt index 0411598..54a34d5 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/ObjectDetection.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/ObjectDetection.kt @@ -35,7 +35,7 @@ class ObjectDetection(reactContext: ReactApplicationContext) : ssdLiteLarge.loadModel(modelSource) promise.resolve(0) } catch (e: Exception) { - promise.reject(e.message!!, ETError.InvalidModelPath.toString()) + promise.reject(e.message!!, ETError.InvalidModelSource.toString()) } } diff --git a/docs/docs/guides/running-llms.md b/docs/docs/guides/running-llms.md index a2df135..eeb1004 100644 --- a/docs/docs/guides/running-llms.md +++ b/docs/docs/guides/running-llms.md @@ -23,7 +23,7 @@ const llama = useLLM({ }); ``` -The code snippet above fetches the model from the specified URL, loads it into memory, and returns an object with various methods and properties for controlling the model. You can monitor the loading progress by checking the `llama.downloadProgress` and `llama.isModelReady` property, and if anything goes wrong, the `llama.error` property will contain the error message. +The code snippet above fetches the model from the specified URL, loads it into memory, and returns an object with various methods and properties for controlling the model. You can monitor the loading progress by checking the `llama.downloadProgress` and `llama.isReady` property, and if anything goes wrong, the `llama.error` property will contain the error message. :::danger[Danger] Lower-end devices might not be able to fit LLMs into memory. We recommend using quantized models to reduce the memory footprint. @@ -50,9 +50,9 @@ Given computational constraints, our architecture is designed to support only on | `generate` | `(input: string) => Promise` | Function to start generating a response with the given input string. | | `response` | `string` | State of the generated response. This field is updated with each token generated by the model | | `error` | string | null | Contains the error message if the model failed to load | -| `isModelGenerating` | `boolean` | Indicates whether the model is currently generating a response | +| `isGenerating` | `boolean` | Indicates whether the model is currently generating a response | | `interrupt` | `() => void` | Function to interrupt the current inference | -| `isModelReady` | `boolean` | Indicates whether the model is ready | +| `isReady` | `boolean` | Indicates whether the model is ready | | `downloadProgress` | `number` | Represents the download progress as a value between 0 and 1, indicating the extent of the model file retrieval. | ### Loading models diff --git a/examples/computer-vision/screens/ClassificationScreen.tsx b/examples/computer-vision/screens/ClassificationScreen.tsx index 3434c0d..5c3b0c3 100644 --- a/examples/computer-vision/screens/ClassificationScreen.tsx +++ b/examples/computer-vision/screens/ClassificationScreen.tsx @@ -44,12 +44,9 @@ export const ClassificationScreen = ({ } }; - if (!model.isModelReady) { + if (!model.isReady) { return ( - + ); } diff --git a/examples/computer-vision/screens/ObjectDetectionScreen.tsx b/examples/computer-vision/screens/ObjectDetectionScreen.tsx index 2e6acaf..280e3c5 100644 --- a/examples/computer-vision/screens/ObjectDetectionScreen.tsx +++ b/examples/computer-vision/screens/ObjectDetectionScreen.tsx @@ -52,10 +52,10 @@ export const ObjectDetectionScreen = ({ } }; - if (!ssdLite.isModelReady) { + if (!ssdLite.isReady) { return ( ); diff --git a/examples/computer-vision/screens/StyleTransferScreen.tsx b/examples/computer-vision/screens/StyleTransferScreen.tsx index 633e329..8ccd0a8 100644 --- a/examples/computer-vision/screens/StyleTransferScreen.tsx +++ b/examples/computer-vision/screens/StyleTransferScreen.tsx @@ -37,12 +37,9 @@ export const StyleTransferScreen = ({ } }; - if (!model.isModelReady) { + if (!model.isReady) { return ( - + ); } diff --git a/examples/computer-vision/yarn.lock b/examples/computer-vision/yarn.lock index 5249305..0b8a954 100644 --- a/examples/computer-vision/yarn.lock +++ b/examples/computer-vision/yarn.lock @@ -3352,7 +3352,7 @@ __metadata: metro-config: ^0.81.0 react: 18.3.1 react-native: 0.76.3 - react-native-executorch: ../../react-native-executorch-0.1.100.tgz + react-native-executorch: ^0.1.3 react-native-image-picker: ^7.2.2 react-native-loading-spinner-overlay: ^3.0.1 react-native-reanimated: ^3.16.3 @@ -6799,13 +6799,13 @@ __metadata: languageName: node linkType: hard -"react-native-executorch@file:../../react-native-executorch-0.1.100.tgz::locator=computer-vision%40workspace%3A.": - version: 0.1.100 - resolution: "react-native-executorch@file:../../react-native-executorch-0.1.100.tgz::locator=computer-vision%40workspace%3A." +"react-native-executorch@npm:^0.1.3": + version: 0.1.3 + resolution: "react-native-executorch@npm:0.1.3" peerDependencies: react: "*" react-native: "*" - checksum: f258452e2050df59e150938f6482ef8eee5fbd4ef4fc4073a920293ca87d543daddf76c560701d0c2626e6677d964b446dad8e670e978ea4f80d0a1bd17dfa03 + checksum: b49f8ca9ae8a7de4a7f2263887537626859507c7d60af47360515b405c7778b48c4c067074e7ce7857782a6737cf47cf5dadada03ae9319a6bf577f8490f431d languageName: node linkType: hard diff --git a/examples/llama/components/Messages.tsx b/examples/llama/components/Messages.tsx index 0d86856..f66d5aa 100644 --- a/examples/llama/components/Messages.tsx +++ b/examples/llama/components/Messages.tsx @@ -9,13 +9,13 @@ import MessageItem from './MessageItem'; interface MessagesComponentProps { chatHistory: Array; llmResponse: string; - isModelGenerating: boolean; + isGenerating: boolean; } export default function Messages({ chatHistory, llmResponse, - isModelGenerating, + isGenerating, }: MessagesComponentProps) { const scrollViewRef = useRef(null); @@ -29,7 +29,7 @@ export default function Messages({ {chatHistory.map((message, index) => ( ))} - {isModelGenerating && ( + {isGenerating && ( diff --git a/examples/llama/screens/ChatScreen.tsx b/examples/llama/screens/ChatScreen.tsx index 10f18ab..4d2f707 100644 --- a/examples/llama/screens/ChatScreen.tsx +++ b/examples/llama/screens/ChatScreen.tsx @@ -31,10 +31,10 @@ export default function ChatScreen() { }); const textInputRef = useRef(null); useEffect(() => { - if (llama.response && !llama.isModelGenerating) { + if (llama.response && !llama.isGenerating) { appendToMessageHistory(llama.response, 'ai'); } - }, [llama.response, llama.isModelGenerating]); + }, [llama.response, llama.isGenerating]); const appendToMessageHistory = (input: string, role: SenderType) => { setChatHistory((prevHistory) => [ @@ -54,9 +54,9 @@ export default function ChatScreen() { } }; - return !llama.isModelReady ? ( + return !llama.isReady ? ( ) : ( @@ -76,7 +76,7 @@ export default function ChatScreen() { ) : ( @@ -108,13 +108,13 @@ export default function ChatScreen() { - !llama.isModelGenerating && (await sendMessage()) + !llama.isGenerating && (await sendMessage()) } > )} - {llama.isModelGenerating && ( + {llama.isGenerating && ( { - if (arr instanceof Int8Array) return 0; - if (arr instanceof Int32Array) return 1; - if (arr instanceof BigInt64Array) return 2; - if (arr instanceof Float32Array) return 3; - if (arr instanceof Float64Array) return 4; - - return -1; -}; +import { useState } from 'react'; +import { _ETModule } from './native/RnExecutorchModules'; +import { getError } from './Error'; +import { ExecutorchModule } from './types/common'; +import { useModule } from './useModule'; interface Props { modelSource: string | number; @@ -21,54 +11,20 @@ interface Props { export const useExecutorchModule = ({ modelSource, }: Props): ExecutorchModule => { - const [error, setError] = useState(null); - const [isModelLoading, setIsModelLoading] = useState(true); - const [isModelGenerating, setIsModelGenerating] = useState(false); - - useEffect(() => { - const loadModel = async () => { - let path = modelSource; - if (typeof modelSource === 'number') { - path = Image.resolveAssetSource(modelSource).uri; - } - - try { - setIsModelLoading(true); - await ETModule.loadModule(path); - setIsModelLoading(false); - } catch (e: unknown) { - setError(getError(e)); - setIsModelLoading(false); - } - }; - loadModel(); - }, [modelSource]); - - const forward = async (input: ETInput, shape: number[]) => { - if (isModelLoading) { - throw new Error(getError(ETError.ModuleNotLoaded)); - } - - const inputType = getTypeIdentifier(input); - if (inputType === -1) { - throw new Error(getError(ETError.InvalidArgument)); - } - - try { - const numberArray = [...input]; - setIsModelGenerating(true); - const output = await ETModule.forward(numberArray, shape, inputType); - setIsModelGenerating(false); - return output; - } catch (e) { - setIsModelGenerating(false); - throw new Error(getError(e)); - } - }; + const [module] = useState(() => new _ETModule()); + const { + error, + isReady, + isGenerating, + forwardETInput: forward, + } = useModule({ + modelSource, + module, + }); const loadMethod = async (methodName: string) => { try { - await ETModule.loadMethod(methodName); + await module.loadMethod(methodName); } catch (e) { throw new Error(getError(e)); } @@ -79,11 +35,11 @@ export const useExecutorchModule = ({ }; return { - error: error, - isModelLoading: isModelLoading, - isModelGenerating: isModelGenerating, - forward: forward, - loadMethod: loadMethod, - loadForward: loadForward, + error, + isReady, + isGenerating, + forward, + loadMethod, + loadForward, }; }; diff --git a/src/LLM.ts b/src/LLM.ts index cc33271..3cd67ff 100644 --- a/src/LLM.ts +++ b/src/LLM.ts @@ -1,6 +1,6 @@ import { useCallback, useEffect, useRef, useState } from 'react'; import { EventSubscription, Image } from 'react-native'; -import { ResourceSource, Model } from './types'; +import { ResourceSource, Model } from './types/common'; import { DEFAULT_CONTEXT_WINDOW_LENGTH, DEFAULT_SYSTEM_PROMPT, @@ -24,8 +24,8 @@ export const useLLM = ({ contextWindowLength?: number; }): Model => { const [error, setError] = useState(null); - const [isModelReady, setIsModelReady] = useState(false); - const [isModelGenerating, setIsModelGenerating] = useState(false); + const [isReady, setIsReady] = useState(false); + const [isGenerating, setIsGenerating] = useState(false); const [response, setResponse] = useState(''); const [downloadProgress, setDownloadProgress] = useState(0); const downloadProgressListener = useRef(null); @@ -65,7 +65,7 @@ export const useLLM = ({ contextWindowLength ); - setIsModelReady(true); + setIsReady(true); tokenGeneratedListener.current = LLM.onToken( (data: string | undefined) => { @@ -75,13 +75,13 @@ export const useLLM = ({ if (data !== EOT_TOKEN) { setResponse((prevResponse) => prevResponse + data); } else { - setIsModelGenerating(false); + setIsGenerating(false); } } ); } catch (err) { const message = (err as Error).message; - setIsModelReady(false); + setIsReady(false); setError(message); } }; @@ -99,7 +99,7 @@ export const useLLM = ({ const generate = useCallback( async (input: string): Promise => { - if (!isModelReady) { + if (!isReady) { throw new Error('Model is still loading'); } if (error) { @@ -108,21 +108,23 @@ export const useLLM = ({ try { setResponse(''); - setIsModelGenerating(true); + setIsGenerating(true); await LLM.runInference(input); } catch (err) { - setIsModelGenerating(false); + setIsGenerating(false); throw new Error((err as Error).message); } }, - [isModelReady, error] + [isReady, error] ); return { generate, error, - isModelReady, - isModelGenerating, + isReady, + isGenerating, + isModelReady: isReady, + isModelGenerating: isGenerating, response, downloadProgress, interrupt, diff --git a/src/StyleTransfer.ts b/src/StyleTransfer.ts deleted file mode 100644 index 16fe521..0000000 --- a/src/StyleTransfer.ts +++ /dev/null @@ -1,64 +0,0 @@ -import { useEffect, useState } from 'react'; -import { Image } from 'react-native'; -import { StyleTransfer } from './native/RnExecutorchModules'; -import { ETError, getError } from './Error'; - -interface Props { - modelSource: string | number; -} - -interface StyleTransferModule { - error: string | null; - isModelReady: boolean; - isModelGenerating: boolean; - forward: (input: string) => Promise; -} - -export const useStyleTransfer = ({ - modelSource, -}: Props): StyleTransferModule => { - const [error, setError] = useState(null); - const [isModelReady, setIsModelReady] = useState(false); - const [isModelGenerating, setIsModelGenerating] = useState(false); - - useEffect(() => { - const loadModel = async () => { - let path = modelSource; - - if (typeof modelSource === 'number') { - path = Image.resolveAssetSource(modelSource).uri; - } - - try { - setIsModelReady(false); - await StyleTransfer.loadModule(path); - setIsModelReady(true); - } catch (e) { - setError(getError(e)); - } - }; - - loadModel(); - }, [modelSource]); - - const forward = async (input: string) => { - if (!isModelReady) { - throw new Error(getError(ETError.ModuleNotLoaded)); - } - if (isModelGenerating) { - throw new Error(getError(ETError.ModelGenerating)); - } - - try { - setIsModelGenerating(true); - const output = await StyleTransfer.forward(input); - return output; - } catch (e) { - throw new Error(getError(e)); - } finally { - setIsModelGenerating(false); - } - }; - - return { error, isModelReady, isModelGenerating, forward }; -}; diff --git a/src/index.tsx b/src/index.tsx index a8d0404..74cfd13 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -1,7 +1,7 @@ export * from './ETModule'; export * from './LLM'; -export * from './StyleTransfer'; -export * from './models/Classification'; export * from './constants/modelUrls'; -export * from './models/object_detection/ObjectDetection'; -export * from './models/object_detection/types'; +export * from './models/Classification'; +export * from './models/ObjectDetection'; +export * from './models/StyleTransfer'; +export * from './types/object_detection'; diff --git a/src/models/Classification.ts b/src/models/Classification.ts index 4ece3f6..6bd8dfb 100644 --- a/src/models/Classification.ts +++ b/src/models/Classification.ts @@ -1,7 +1,6 @@ -import { useEffect, useState } from 'react'; -import { Image } from 'react-native'; -import { Classification } from '../native/RnExecutorchModules'; -import { ETError, getError } from '../Error'; +import { useState } from 'react'; +import { _ClassificationModule } from '../native/RnExecutorchModules'; +import { useModule } from '../useModule'; interface Props { modelSource: string | number; @@ -9,62 +8,24 @@ interface Props { interface ClassificationModule { error: string | null; - isModelReady: boolean; - isModelGenerating: boolean; + isReady: boolean; + isGenerating: boolean; forward: (input: string) => Promise<{ [category: string]: number }>; } export const useClassification = ({ modelSource, }: Props): ClassificationModule => { - const [error, setError] = useState(null); - const [isModelReady, setIsModelReady] = useState(false); - const [isModelGenerating, setIsModelGenerating] = useState(false); - - useEffect(() => { - const loadModel = async () => { - let path = modelSource; - - if (typeof modelSource === 'number') { - path = Image.resolveAssetSource(modelSource).uri; - } - - try { - setIsModelReady(false); - await Classification.loadModule(path); - } catch (e) { - setError(getError(e)); - } finally { - setIsModelReady(true); - } - }; - - loadModel(); - }, [modelSource]); - - const forward = async (input: string) => { - if (!isModelReady) { - throw new Error(getError(ETError.ModuleNotLoaded)); - } - - if (error) { - throw new Error(error); - } - - if (isModelGenerating) { - throw new Error(getError(ETError.ModelGenerating)); - } - - try { - setIsModelGenerating(true); - const output = await Classification.forward(input); - setIsModelGenerating(false); - return output; - } catch (e) { - setIsModelGenerating(false); - throw new Error(getError(e)); - } - }; - - return { error, isModelReady, isModelGenerating, forward }; + const [module, _] = useState(() => new _ClassificationModule()); + const { + error, + isReady, + isGenerating, + forwardImage: forward, + } = useModule({ + modelSource, + module, + }); + + return { error, isReady, isGenerating, forward }; }; diff --git a/src/models/ObjectDetection.ts b/src/models/ObjectDetection.ts new file mode 100644 index 0000000..fda2fd0 --- /dev/null +++ b/src/models/ObjectDetection.ts @@ -0,0 +1,32 @@ +import { useState } from 'react'; +import { _ObjectDetectionModule } from '../native/RnExecutorchModules'; +import { useModule } from '../useModule'; +import { Detection } from '../types/object_detection'; + +interface Props { + modelSource: string | number; +} + +interface ObjectDetectionModule { + error: string | null; + isReady: boolean; + isGenerating: boolean; + forward: (input: string) => Promise; +} + +export const useObjectDetection = ({ + modelSource, +}: Props): ObjectDetectionModule => { + const [module, _] = useState(() => new _ObjectDetectionModule()); + const { + error, + isReady, + isGenerating, + forwardImage: forward, + } = useModule({ + modelSource, + module, + }); + + return { error, isReady, isGenerating, forward }; +}; diff --git a/src/models/StyleTransfer.ts b/src/models/StyleTransfer.ts new file mode 100644 index 0000000..215f5ae --- /dev/null +++ b/src/models/StyleTransfer.ts @@ -0,0 +1,31 @@ +import { useState } from 'react'; +import { _StyleTransferModule } from '../native/RnExecutorchModules'; +import { useModule } from '../useModule'; + +interface Props { + modelSource: string | number; +} + +interface StyleTransferModule { + error: string | null; + isReady: boolean; + isGenerating: boolean; + forward: (input: string) => Promise; +} + +export const useStyleTransfer = ({ + modelSource, +}: Props): StyleTransferModule => { + const [module, _] = useState(() => new _StyleTransferModule()); + const { + error, + isReady, + isGenerating, + forwardImage: forward, + } = useModule({ + modelSource, + module, + }); + + return { error, isReady, isGenerating, forward }; +}; diff --git a/src/models/object_detection/ObjectDetection.ts b/src/models/object_detection/ObjectDetection.ts deleted file mode 100644 index 8fd2a8e..0000000 --- a/src/models/object_detection/ObjectDetection.ts +++ /dev/null @@ -1,63 +0,0 @@ -import { useEffect, useState } from 'react'; -import { Image } from 'react-native'; -import { ETError, getError } from '../../Error'; -import { ObjectDetection } from '../../native/RnExecutorchModules'; -import { Detection } from './types'; - -interface Props { - modelSource: string | number; -} - -interface ObjectDetectionModule { - error: string | null; - isModelReady: boolean; - isModelGenerating: boolean; - forward: (input: string) => Promise; -} - -export const useObjectDetection = ({ - modelSource, -}: Props): ObjectDetectionModule => { - const [error, setError] = useState(null); - const [isModelReady, setIsModelReady] = useState(false); - const [isModelGenerating, setIsModelGenerating] = useState(false); - - useEffect(() => { - const loadModel = async () => { - let path = modelSource; - if (typeof modelSource === 'number') { - path = Image.resolveAssetSource(modelSource).uri; - } - - try { - setIsModelReady(false); - await ObjectDetection.loadModule(path); - setIsModelReady(true); - } catch (e) { - setError(getError(e)); - } - }; - - loadModel(); - }, [modelSource]); - - const forward = async (input: string) => { - if (!isModelReady) { - throw new Error(getError(ETError.ModuleNotLoaded)); - } - if (isModelGenerating) { - throw new Error(getError(ETError.ModelGenerating)); - } - try { - setIsModelGenerating(true); - const output = await ObjectDetection.forward(input); - return output; - } catch (e) { - throw new Error(getError(e)); - } finally { - setIsModelGenerating(false); - } - }; - - return { error, isModelReady, isModelGenerating, forward }; -}; diff --git a/src/native/NativeObjectDetection.ts b/src/native/NativeObjectDetection.ts index 8b70f3d..1bb52b6 100644 --- a/src/native/NativeObjectDetection.ts +++ b/src/native/NativeObjectDetection.ts @@ -1,6 +1,6 @@ import type { TurboModule } from 'react-native'; import { TurboModuleRegistry } from 'react-native'; -import { Detection } from '../models/object_detection/types'; +import { Detection } from '../types/object_detection'; export interface Spec extends TurboModule { loadModule(modelSource: string): Promise; diff --git a/src/native/RnExecutorchModules.ts b/src/native/RnExecutorchModules.ts index 72170e3..8a80b59 100644 --- a/src/native/RnExecutorchModules.ts +++ b/src/native/RnExecutorchModules.ts @@ -71,4 +71,57 @@ const StyleTransfer = StyleTransferSpec } ); -export { LLM, ETModule, Classification, ObjectDetection, StyleTransfer }; +class _ObjectDetectionModule { + async forward(input: string) { + return await ObjectDetection.forward(input); + } + async loadModule(modelSource: string | number) { + return await ObjectDetection.loadModule(modelSource); + } +} + +class _StyleTransferModule { + async forward(input: string) { + return await StyleTransfer.forward(input); + } + async loadModule(modelSource: string | number) { + return await StyleTransfer.loadModule(modelSource); + } +} + +class _ClassificationModule { + async forward(input: string) { + return await Classification.forward(input); + } + async loadModule(modelSource: string | number) { + return await Classification.loadModule(modelSource); + } +} + +class _ETModule { + async forward( + input: number[], + shape: number[], + inputType: number + ): Promise { + return await ETModule.forward(input, shape, inputType); + } + async loadModule(modelSource: string) { + return await ETModule.loadModule(modelSource); + } + async loadMethod(methodName: string): Promise { + return await ETModule.loadMethod(methodName); + } +} + +export { + LLM, + ETModule, + Classification, + ObjectDetection, + StyleTransfer, + _ETModule, + _ClassificationModule, + _StyleTransferModule, + _ObjectDetectionModule, +}; diff --git a/src/types.ts b/src/types/common.ts similarity index 63% rename from src/types.ts rename to src/types/common.ts index 7d32ae8..f12643d 100644 --- a/src/types.ts +++ b/src/types/common.ts @@ -1,3 +1,10 @@ +import { + _ClassificationModule, + _StyleTransferModule, + _ObjectDetectionModule, + ETModule, +} from '../native/RnExecutorchModules'; + export type ResourceSource = string | number; export interface Model { @@ -6,7 +13,9 @@ export interface Model { downloadProgress: number; error: string | null; isModelGenerating: boolean; + isGenerating: boolean; isModelReady: boolean; + isReady: boolean; interrupt: () => void; } @@ -19,9 +28,15 @@ export type ETInput = export interface ExecutorchModule { error: string | null; - isModelLoading: boolean; - isModelGenerating: boolean; + isReady: boolean; + isGenerating: boolean; forward: (input: ETInput, shape: number[]) => Promise; loadMethod: (methodName: string) => Promise; loadForward: () => Promise; } + +export type module = + | _ClassificationModule + | _StyleTransferModule + | _ObjectDetectionModule + | typeof ETModule; diff --git a/src/models/object_detection/types.ts b/src/types/object_detection.ts similarity index 100% rename from src/models/object_detection/types.ts rename to src/types/object_detection.ts diff --git a/src/useModule.ts b/src/useModule.ts new file mode 100644 index 0000000..66c2fd4 --- /dev/null +++ b/src/useModule.ts @@ -0,0 +1,100 @@ +import { useEffect, useState } from 'react'; +import { Image } from 'react-native'; +import { ETError, getError } from './Error'; +import { ETInput, module } from './types/common'; + +const getTypeIdentifier = (arr: ETInput): number => { + if (arr instanceof Int8Array) return 0; + if (arr instanceof Int32Array) return 1; + if (arr instanceof BigInt64Array) return 2; + if (arr instanceof Float32Array) return 3; + if (arr instanceof Float64Array) return 4; + + return -1; +}; + +interface Props { + modelSource: string | number; + module: module; +} + +interface _Module { + error: string | null; + isReady: boolean; + isGenerating: boolean; + forwardETInput: (input: ETInput, shape: number[]) => Promise; + forwardImage: (input: string) => Promise; +} + +export const useModule = ({ modelSource, module }: Props): _Module => { + const [error, setError] = useState(null); + const [isReady, setIsReady] = useState(false); + const [isGenerating, setIsGenerating] = useState(false); + + useEffect(() => { + const loadModel = async () => { + if (!modelSource) return; + let path = modelSource; + + if (typeof modelSource === 'number') { + path = Image.resolveAssetSource(modelSource).uri; + } + + try { + setIsReady(false); + await module.loadModule(path); + setIsReady(true); + } catch (e) { + setError(getError(e)); + } + }; + + loadModel(); + }, [modelSource, module]); + + const forwardImage = async (input: string) => { + if (!isReady) { + throw new Error(getError(ETError.ModuleNotLoaded)); + } + if (isGenerating) { + throw new Error(getError(ETError.ModelGenerating)); + } + + try { + setIsGenerating(true); + const output = await module.forward(input); + return output; + } catch (e) { + throw new Error(getError(e)); + } finally { + setIsGenerating(false); + } + }; + + const forwardETInput = async (input: ETInput, shape: number[]) => { + if (!isReady) { + throw new Error(getError(ETError.ModuleNotLoaded)); + } + if (isGenerating) { + throw new Error(getError(ETError.ModelGenerating)); + } + + const inputType = getTypeIdentifier(input); + if (inputType === -1) { + throw new Error(getError(ETError.InvalidArgument)); + } + + try { + const numberArray = [...input]; + setIsGenerating(true); + const output = await module.forward(numberArray, shape, inputType); + setIsGenerating(false); + return output; + } catch (e) { + setIsGenerating(false); + throw new Error(getError(e)); + } + }; + + return { error, isReady, isGenerating, forwardETInput, forwardImage }; +};