chore: refactor ts cv hooks #64

Merged · 10 commits · Dec 18, 2024
@@ -35,7 +35,7 @@ class ObjectDetection(reactContext: ReactApplicationContext) :
ssdLiteLarge.loadModel(modelSource)
promise.resolve(0)
} catch (e: Exception) {
promise.reject(e.message!!, ETError.InvalidModelPath.toString())
promise.reject(e.message!!, ETError.InvalidModelSource.toString())
}
}

6 changes: 3 additions & 3 deletions docs/docs/guides/running-llms.md
@@ -23,7 +23,7 @@ const llama = useLLM({
});
```

The code snippet above fetches the model from the specified URL, loads it into memory, and returns an object with various methods and properties for controlling the model. You can monitor the loading progress by checking the `llama.downloadProgress` and `llama.isModelReady` properties, and if anything goes wrong, the `llama.error` property will contain the error message.
The code snippet above fetches the model from the specified URL, loads it into memory, and returns an object with various methods and properties for controlling the model. You can monitor the loading progress by checking the `llama.downloadProgress` and `llama.isReady` properties, and if anything goes wrong, the `llama.error` property will contain the error message.

:::danger[Danger]
Lower-end devices might not be able to fit LLMs into memory. We recommend using quantized models to reduce the memory footprint.
@@ -50,9 +50,9 @@ Given computational constraints, our architecture is designed to support only on
| `generate` | `(input: string) => Promise<void>` | Function to start generating a response with the given input string. |
| `response` | `string` | State of the generated response. This field is updated with each token generated by the model |
| `error` | <code>string &#124; null</code> | Contains the error message if the model failed to load |
| `isModelGenerating` | `boolean` | Indicates whether the model is currently generating a response |
| `isGenerating` | `boolean` | Indicates whether the model is currently generating a response |
| `interrupt` | `() => void` | Function to interrupt the current inference |
| `isModelReady` | `boolean` | Indicates whether the model is ready |
| `isReady` | `boolean` | Indicates whether the model is ready |
| `downloadProgress` | `number` | Represents the download progress as a value between 0 and 1, indicating the extent of the model file retrieval. |
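
Putting the renamed fields together, a consuming component might look roughly like the sketch below. This only illustrates the new `isReady`/`isGenerating`/`downloadProgress` names; the props passed to `useLLM` and the URLs are placeholder assumptions, not taken from this PR.

```tsx
import React, { useState } from 'react';
import { Button, Text, TextInput, View } from 'react-native';
import { useLLM } from 'react-native-executorch';

export default function LlamaChatDemo() {
  // modelSource/tokenizerSource and the URLs are illustrative placeholders.
  const llama = useLLM({
    modelSource: 'https://example.com/llama-3.2-1B.pte',
    tokenizerSource: 'https://example.com/tokenizer.bin',
  });
  const [prompt, setPrompt] = useState('');

  if (!llama.isReady) {
    // downloadProgress goes from 0 to 1 while the model file is being fetched.
    return (
      <Text>{`Loading the model... ${(llama.downloadProgress * 100).toFixed(0)} %`}</Text>
    );
  }

  return (
    <View>
      <TextInput value={prompt} onChangeText={setPrompt} />
      <Button
        title={llama.isGenerating ? 'Stop' : 'Send'}
        onPress={() =>
          llama.isGenerating ? llama.interrupt() : llama.generate(prompt)
        }
      />
      {/* response is appended to token by token while isGenerating is true */}
      <Text>{llama.response}</Text>
      {llama.error !== null && <Text>{`Error: ${llama.error}`}</Text>}
    </View>
  );
}
```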

### Loading models
7 changes: 2 additions & 5 deletions examples/computer-vision/screens/ClassificationScreen.tsx
@@ -44,12 +44,9 @@ export const ClassificationScreen = ({
}
};

if (!model.isModelReady) {
if (!model.isReady) {
return (
<Spinner
visible={!model.isModelReady}
textContent={`Loading the model...`}
/>
<Spinner visible={!model.isReady} textContent={`Loading the model...`} />
);
}

4 changes: 2 additions & 2 deletions examples/computer-vision/screens/ObjectDetectionScreen.tsx
@@ -52,10 +52,10 @@
}
};

if (!ssdLite.isModelReady) {
if (!ssdLite.isReady) {
return (
<Spinner
visible={!ssdLite.isModelReady}
visible={!ssdLite.isReady}
textContent={`Loading the model...`}
/>
);
@@ -76,7 +76,7 @@
/>
) : (
<Image
style={{ width: '100%', height: '100%' }}
resizeMode="contain"
source={require('../assets/icons/executorch_logo.png')}
/>
7 changes: 2 additions & 5 deletions examples/computer-vision/screens/StyleTransferScreen.tsx
@@ -37,12 +37,9 @@ export const StyleTransferScreen = ({
}
};

if (!model.isModelReady) {
if (!model.isReady) {
return (
<Spinner
visible={!model.isModelReady}
textContent={`Loading the model...`}
/>
<Spinner visible={!model.isReady} textContent={`Loading the model...`} />
);
}

10 changes: 5 additions & 5 deletions examples/computer-vision/yarn.lock
@@ -3352,7 +3352,7 @@ __metadata:
metro-config: ^0.81.0
react: 18.3.1
react-native: 0.76.3
react-native-executorch: ../../react-native-executorch-0.1.100.tgz
react-native-executorch: ^0.1.3
react-native-image-picker: ^7.2.2
react-native-loading-spinner-overlay: ^3.0.1
react-native-reanimated: ^3.16.3
@@ -6799,13 +6799,13 @@
languageName: node
linkType: hard

"react-native-executorch@file:../../react-native-executorch-0.1.100.tgz::locator=computer-vision%40workspace%3A.":
version: 0.1.100
resolution: "react-native-executorch@file:../../react-native-executorch-0.1.100.tgz::locator=computer-vision%40workspace%3A."
"react-native-executorch@npm:^0.1.3":
version: 0.1.3
resolution: "react-native-executorch@npm:0.1.3"
peerDependencies:
react: "*"
react-native: "*"
checksum: f258452e2050df59e150938f6482ef8eee5fbd4ef4fc4073a920293ca87d543daddf76c560701d0c2626e6677d964b446dad8e670e978ea4f80d0a1bd17dfa03
checksum: b49f8ca9ae8a7de4a7f2263887537626859507c7d60af47360515b405c7778b48c4c067074e7ce7857782a6737cf47cf5dadada03ae9319a6bf577f8490f431d
languageName: node
linkType: hard

6 changes: 3 additions & 3 deletions examples/llama/components/Messages.tsx
@@ -9,13 +9,13 @@ import MessageItem from './MessageItem';
interface MessagesComponentProps {
chatHistory: Array<MessageType>;
llmResponse: string;
isModelGenerating: boolean;
isGenerating: boolean;
}

export default function Messages({
chatHistory,
llmResponse,
isModelGenerating,
isGenerating,
}: MessagesComponentProps) {
const scrollViewRef = useRef<ScrollView>(null);

@@ -29,7 +29,7 @@ export default function Messages({
{chatHistory.map((message, index) => (
<MessageItem key={index} message={message} />
))}
{isModelGenerating && (
{isGenerating && (
<View style={styles.aiMessage}>
<View style={styles.aiMessageIconContainer}>
<LlamaIcon width={24} height={24} />
14 changes: 7 additions & 7 deletions examples/llama/screens/ChatScreen.tsx
@@ -31,10 +31,10 @@
});
const textInputRef = useRef<TextInput>(null);
useEffect(() => {
if (llama.response && !llama.isModelGenerating) {
if (llama.response && !llama.isGenerating) {
appendToMessageHistory(llama.response, 'ai');
}
}, [llama.response, llama.isModelGenerating]);
}, [llama.response, llama.isGenerating]);

const appendToMessageHistory = (input: string, role: SenderType) => {
setChatHistory((prevHistory) => [
@@ -54,16 +54,16 @@
}
};

return !llama.isModelReady ? (
return !llama.isReady ? (
<Spinner
visible={!llama.isModelReady}
visible={!llama.isReady}
textContent={`Loading the model ${(llama.downloadProgress * 100).toFixed(0)} %`}
/>
) : (
<SafeAreaView style={styles.container}>
<TouchableWithoutFeedback onPress={Keyboard.dismiss}>
<KeyboardAvoidingView
style={{ flex: 1 }}
behavior={Platform.OS === 'ios' ? 'padding' : 'height'}
keyboardVerticalOffset={Platform.OS === 'android' ? 30 : 0}
>
@@ -76,7 +76,7 @@
<Messages
chatHistory={chatHistory}
llmResponse={llama.response}
isModelGenerating={llama.isModelGenerating}
isGenerating={llama.isGenerating}
/>
</View>
) : (
@@ -108,13 +108,13 @@
<TouchableOpacity
style={styles.sendChatTouchable}
onPress={async () =>
!llama.isModelGenerating && (await sendMessage())
!llama.isGenerating && (await sendMessage())
}
>
<SendIcon height={24} width={24} padding={4} margin={8} />
</TouchableOpacity>
)}
{llama.isModelGenerating && (
{llama.isGenerating && (
<TouchableOpacity
style={styles.sendChatTouchable}
onPress={llama.interrupt}
88 changes: 22 additions & 66 deletions src/ETModule.ts
@@ -1,18 +1,8 @@
import { useEffect, useState } from 'react';
import { Image } from 'react-native';
import { ETModule } from './native/RnExecutorchModules';
import { ETError, getError } from './Error';
import { ETInput, ExecutorchModule } from './types';

const getTypeIdentifier = (arr: ETInput): number => {
if (arr instanceof Int8Array) return 0;
if (arr instanceof Int32Array) return 1;
if (arr instanceof BigInt64Array) return 2;
if (arr instanceof Float32Array) return 3;
if (arr instanceof Float64Array) return 4;

return -1;
};
import { useState } from 'react';
import { _ETModule } from './native/RnExecutorchModules';
import { getError } from './Error';
import { ExecutorchModule } from './types/common';
import { useModule } from './useModule';

interface Props {
modelSource: string | number;
@@ -21,54 +21,11 @@ interface Props {
export const useExecutorchModule = ({
modelSource,
}: Props): ExecutorchModule => {
const [error, setError] = useState<string | null>(null);
const [isModelLoading, setIsModelLoading] = useState(true);
const [isModelGenerating, setIsModelGenerating] = useState(false);

useEffect(() => {
const loadModel = async () => {
let path = modelSource;
if (typeof modelSource === 'number') {
path = Image.resolveAssetSource(modelSource).uri;
}

try {
setIsModelLoading(true);
await ETModule.loadModule(path);
setIsModelLoading(false);
} catch (e: unknown) {
setError(getError(e));
setIsModelLoading(false);
}
};
loadModel();
}, [modelSource]);

const forward = async (input: ETInput, shape: number[]) => {
if (isModelLoading) {
throw new Error(getError(ETError.ModuleNotLoaded));
}

const inputType = getTypeIdentifier(input);
if (inputType === -1) {
throw new Error(getError(ETError.InvalidArgument));
}

try {
const numberArray = [...input];
setIsModelGenerating(true);
const output = await ETModule.forward(numberArray, shape, inputType);
setIsModelGenerating(false);
return output;
} catch (e) {
setIsModelGenerating(false);
throw new Error(getError(e));
}
};
const [module] = useState(() => new _ETModule());
const {
error,
isReady,
isGenerating,
forwardETInput: forward,
} = useModule({
modelSource,
module,
});

const loadMethod = async (methodName: string) => {
try {
await ETModule.loadMethod(methodName);
await module.loadMethod(methodName);
} catch (e) {
throw new Error(getError(e));
}
@@ -79,11 +79,11 @@ export const useExecutorchModule = ({
};

return {
error: error,
isModelLoading: isModelLoading,
isModelGenerating: isModelGenerating,
forward: forward,
loadMethod: loadMethod,
loadForward: loadForward,
error,
isReady,
isGenerating,
forward,
loadMethod,
loadForward,
};
};
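
The loading and inference logic deleted from `useExecutorchModule` above is delegated to a shared `useModule` hook, whose implementation is not included in this diff. Below is a minimal sketch of what such a hook could look like, reconstructed from the code removed here; the hook's real implementation in the PR may differ, and the types and import paths are assumptions.

```typescript
import { useEffect, useState } from 'react';
import { Image } from 'react-native';
import { ETError, getError } from './Error';
import { ETInput } from './types/common';

// Maps a typed array to the integer type id expected by the native module
// (the same mapping that was previously inlined in ETModule.ts).
const getTypeIdentifier = (arr: ETInput): number => {
  if (arr instanceof Int8Array) return 0;
  if (arr instanceof Int32Array) return 1;
  if (arr instanceof BigInt64Array) return 2;
  if (arr instanceof Float32Array) return 3;
  if (arr instanceof Float64Array) return 4;
  return -1;
};

interface UseModuleProps {
  modelSource: string | number;
  // Minimal surface assumed from how useExecutorchModule consumes the module.
  module: {
    loadModule: (source: string | number) => Promise<void>;
    forward: (input: number[], shape: number[], inputType: number) => Promise<number[]>;
  };
}

export const useModule = ({ modelSource, module }: UseModuleProps) => {
  const [error, setError] = useState<string | null>(null);
  const [isReady, setIsReady] = useState(false);
  const [isGenerating, setIsGenerating] = useState(false);

  useEffect(() => {
    const loadModel = async () => {
      // Numeric sources are bundled assets; resolve them to a URI first.
      let path = modelSource;
      if (typeof modelSource === 'number') {
        path = Image.resolveAssetSource(modelSource).uri;
      }
      try {
        setIsReady(false);
        await module.loadModule(path);
        setIsReady(true);
      } catch (e) {
        setError(getError(e));
      }
    };
    loadModel();
  }, [modelSource, module]);

  const forwardETInput = async (input: ETInput, shape: number[]) => {
    if (!isReady) {
      throw new Error(getError(ETError.ModuleNotLoaded));
    }
    const inputType = getTypeIdentifier(input);
    if (inputType === -1) {
      throw new Error(getError(ETError.InvalidArgument));
    }
    try {
      setIsGenerating(true);
      // Same array conversion the old inline forward performed before calling native code.
      return await module.forward([...input] as number[], shape, inputType);
    } catch (e) {
      throw new Error(getError(e));
    } finally {
      setIsGenerating(false);
    }
  };

  return { error, isReady, isGenerating, forwardETInput };
};
```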
26 changes: 14 additions & 12 deletions src/LLM.ts
@@ -1,6 +1,6 @@
import { useCallback, useEffect, useRef, useState } from 'react';
import { EventSubscription, Image } from 'react-native';
import { ResourceSource, Model } from './types';
import { ResourceSource, Model } from './types/common';
import {
DEFAULT_CONTEXT_WINDOW_LENGTH,
DEFAULT_SYSTEM_PROMPT,
@@ -24,8 +24,8 @@ export const useLLM = ({
contextWindowLength?: number;
}): Model => {
const [error, setError] = useState<string | null>(null);
const [isModelReady, setIsModelReady] = useState(false);
const [isModelGenerating, setIsModelGenerating] = useState(false);
const [isReady, setIsReady] = useState(false);
const [isGenerating, setIsGenerating] = useState(false);
const [response, setResponse] = useState('');
const [downloadProgress, setDownloadProgress] = useState(0);
const downloadProgressListener = useRef<null | EventSubscription>(null);
@@ -65,7 +65,7 @@
contextWindowLength
);

setIsModelReady(true);
setIsReady(true);

tokenGeneratedListener.current = LLM.onToken(
(data: string | undefined) => {
@@ -75,13 +75,13 @@
if (data !== EOT_TOKEN) {
setResponse((prevResponse) => prevResponse + data);
} else {
setIsModelGenerating(false);
setIsGenerating(false);
}
}
);
} catch (err) {
const message = (err as Error).message;
setIsModelReady(false);
setIsReady(false);
setError(message);
}
};
@@ -99,7 +99,7 @@

const generate = useCallback(
async (input: string): Promise<void> => {
if (!isModelReady) {
if (!isReady) {
throw new Error('Model is still loading');
}
if (error) {
@@ -108,21 +108,23 @@

try {
setResponse('');
setIsModelGenerating(true);
setIsGenerating(true);
await LLM.runInference(input);
} catch (err) {
setIsModelGenerating(false);
setIsGenerating(false);
throw new Error((err as Error).message);
}
},
[isModelReady, error]
[isReady, error]
);

return {
generate,
error,
isModelReady,
isModelGenerating,
isReady,
isGenerating,
isModelReady: isReady,
isModelGenerating: isGenerating,
response,
downloadProgress,
interrupt,