Skip to content

Commit

Permalink
@jakmro/standardize naming (#62)
Browse files Browse the repository at this point in the history
## Description
Standardize naming and fix URLs

### Type of change
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing
documentation)

### Tested on
- [x] iOS
- [x] Android

### Checklist
- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [x] My changes generate no new warnings
  • Loading branch information
jakmro authored Dec 17, 2024
1 parent 1499364 commit d79ae01
Show file tree
Hide file tree
Showing 16 changed files with 59 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class Classification(reactContext: ReactApplicationContext) :
classificationModel.loadModel(modelSource)
promise.resolve(0)
} catch (e: Exception) {
promise.reject(e.message!!, ETError.InvalidModelPath.toString())
promise.reject(e.message!!, ETError.InvalidModelSource.toString())
}
}

Expand Down
6 changes: 3 additions & 3 deletions android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(react
return NAME
}

override fun loadModule(modelPath: String, promise: Promise) {
override fun loadModule(modelSource: String, promise: Promise) {
Fetcher.downloadModel(
reactApplicationContext,
modelPath,
modelSource,
) { path, error ->
if (error != null) {
promise.reject(error.message!!, ETError.InvalidModelPath.toString())
promise.reject(error.message!!, ETError.InvalidModelSource.toString())
return@downloadModel
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class StyleTransfer(reactContext: ReactApplicationContext) :
styleTransferModel.loadModel(modelSource)
promise.resolve(0)
} catch (e: Exception) {
promise.reject(e.message!!, ETError.InvalidModelPath.toString())
promise.reject(e.message!!, ETError.InvalidModelSource.toString())
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ enum class ETError(val code: Int) {
UndefinedError(0x65),
ModuleNotLoaded(0x66),
FileWriteFailed(0x67),
InvalidModelPath(0xff),
InvalidModelSource(0xff),

// System errors
Ok(0x00),
Expand Down
6 changes: 3 additions & 3 deletions docs/docs/guides/running-llms.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ React Native ExecuTorch supports Llama 3.2 models, including quantized versions.
In order to load a model into the app, you need to run the following code:

```typescript
import { useLLM, LLAMA3_2_1B_URL } from 'react-native-executorch';
import { useLLM, LLAMA3_2_1B } from 'react-native-executorch';

const llama = useLLM({
modelSource: LLAMA3_2_1B_URL,
modelSource: LLAMA3_2_1B,
tokenizer: require('../assets/tokenizer.bin'),
contextWindowLength: 3,
});
Expand Down Expand Up @@ -91,7 +91,7 @@ In order to send a message to the model, one can use the following code:

```typescript
const llama = useLLM(
modelSource: LLAMA3_2_1B_URL,
modelSource: LLAMA3_2_1B,
tokenizer: require('../assets/tokenizer.bin'),
);

Expand Down
2 changes: 1 addition & 1 deletion examples/computer-vision/screens/ClassificationScreen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export const ClassificationScreen = ({
);

const model = useClassification({
modulePath: EFFICIENTNET_V2_S,
modelSource: EFFICIENTNET_V2_S,
});

const handleCameraPress = async (isCamera: boolean) => {
Expand Down
4 changes: 2 additions & 2 deletions examples/computer-vision/screens/ObjectDetectionScreen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { getImage } from '../utils';
import {
Detection,
useObjectDetection,
SSDLITE_320_MOBILENET_V3_LARGE_URL,
SSDLITE_320_MOBILENET_V3_LARGE,
} from 'react-native-executorch';
import { View, StyleSheet, Image } from 'react-native';
import ImageWithBboxes from '../components/ImageWithBboxes';
Expand All @@ -24,7 +24,7 @@ export const ObjectDetectionScreen = ({
}>();

const ssdLite = useObjectDetection({
modelSource: SSDLITE_320_MOBILENET_V3_LARGE_URL,
modelSource: SSDLITE_320_MOBILENET_V3_LARGE,
});

const handleCameraPress = async (isCamera: boolean) => {
Expand Down
2 changes: 1 addition & 1 deletion examples/computer-vision/screens/StyleTransferScreen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export const StyleTransferScreen = ({
setImageUri: (imageUri: string) => void;
}) => {
const model = useStyleTransfer({
modulePath: STYLE_TRANSFER_CANDY,
modelSource: STYLE_TRANSFER_CANDY,
});

const handleCameraPress = async (isCamera: boolean) => {
Expand Down
4 changes: 2 additions & 2 deletions examples/llama/screens/ChatScreen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import { SafeAreaView } from 'react-native-safe-area-context';
import SWMIcon from '../assets/icons/swm_icon.svg';
import SendIcon from '../assets/icons/send_icon.svg';
import Spinner from 'react-native-loading-spinner-overlay';
import { LLAMA3_2_1B_QLORA_URL, useLLM } from 'react-native-executorch';
import { LLAMA3_2_1B_QLORA, useLLM } from 'react-native-executorch';
import PauseIcon from '../assets/icons/pause_icon.svg';
import ColorPalette from '../colors';
import Messages from '../components/Messages';
Expand All @@ -25,7 +25,7 @@ export default function ChatScreen() {
const [isTextInputFocused, setIsTextInputFocused] = useState(false);
const [userInput, setUserInput] = useState('');
const llama = useLLM({
modelSource: LLAMA3_2_1B_QLORA_URL,
modelSource: LLAMA3_2_1B_QLORA,
tokenizerSource: require('../assets/tokenizer.bin'),
contextWindowLength: 6,
});
Expand Down
2 changes: 1 addition & 1 deletion ios/RnExecutorch/models/BaseModel.mm
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ - (void)loadModel:(NSURL *)modelURL completion:(void (^)(BOOL success, NSNumber*
module = [[ETModel alloc] init];
[Fetcher fetchResource:modelURL resourceType:ResourceType::MODEL completionHandler:^(NSString *filePath, NSError *error) {
if (error) {
completion(NO, @(InvalidModelPath));
completion(NO, @(InvalidModelSource));
return;
}
NSNumber *result = [self->module loadModel: filePath];
Expand Down
2 changes: 1 addition & 1 deletion ios/RnExecutorch/utils/ETError.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ typedef NS_ENUM(NSUInteger, ETError) {
UndefinedError = 0x65,
ModuleNotLoaded = 0x66,
FileWriteFailed = 0x67,
InvalidModelPath = 0xff,
InvalidModelSource = 0xff,

Ok = 0x00,
Internal = 0x01,
Expand Down
12 changes: 6 additions & 6 deletions src/ETModule.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,21 @@ const getTypeIdentifier = (arr: ETInput): number => {
};

interface Props {
modulePath: string | number;
modelSource: string | number;
}

export const useExecutorchModule = ({
modulePath,
modelSource,
}: Props): ExecutorchModule => {
const [error, setError] = useState<string | null>(null);
const [isModelLoading, setIsModelLoading] = useState(true);
const [isModelGenerating, setIsModelGenerating] = useState(false);

useEffect(() => {
const loadModel = async () => {
let path = modulePath;
if (typeof modulePath === 'number') {
path = Image.resolveAssetSource(modulePath).uri;
let path = modelSource;
if (typeof modelSource === 'number') {
path = Image.resolveAssetSource(modelSource).uri;
}

try {
Expand All @@ -42,7 +42,7 @@ export const useExecutorchModule = ({
}
};
loadModel();
}, [modulePath]);
}, [modelSource]);

const forward = async (input: ETInput, shape: number[]) => {
if (isModelLoading) {
Expand Down
2 changes: 1 addition & 1 deletion src/Error.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export enum ETError {
ModuleNotLoaded = 0x66,
FileWriteFailed = 0x67,
ModelGenerating = 0x68,
InvalidModelPath = 0xff,
InvalidModelSource = 0xff,

// ExecuTorch mapped errors
// Based on: https://github.com/pytorch/executorch/blob/main/runtime/core/error.h
Expand Down
12 changes: 6 additions & 6 deletions src/StyleTransfer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { StyleTransfer } from './native/RnExecutorchModules';
import { ETError, getError } from './Error';

interface Props {
modulePath: string | number;
modelSource: string | number;
}

interface StyleTransferModule {
Expand All @@ -15,18 +15,18 @@ interface StyleTransferModule {
}

export const useStyleTransfer = ({
modulePath,
modelSource,
}: Props): StyleTransferModule => {
const [error, setError] = useState<null | string>(null);
const [isModelReady, setIsModelReady] = useState(false);
const [isModelGenerating, setIsModelGenerating] = useState(false);

useEffect(() => {
const loadModel = async () => {
let path = modulePath;
let path = modelSource;

if (typeof modulePath === 'number') {
path = Image.resolveAssetSource(modulePath).uri;
if (typeof modelSource === 'number') {
path = Image.resolveAssetSource(modelSource).uri;
}

try {
Expand All @@ -39,7 +39,7 @@ export const useStyleTransfer = ({
};

loadModel();
}, [modulePath]);
}, [modelSource]);

const forward = async (input: string) => {
if (!isModelReady) {
Expand Down
40 changes: 23 additions & 17 deletions src/constants/modelUrls.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import { Platform } from 'react-native';

// LLM's
export const LLAMA3_2_3B_URL =
export const LLAMA3_2_3B =
'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-3B/original/llama3_2_3B_bf16.pte';
export const LLAMA3_2_3B_QLORA_URL =
export const LLAMA3_2_3B_QLORA =
'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-3B/QLoRA/llama3_2-3B_qat_lora.pte';
export const LLAMA3_2_3B_SPINQUANT_URL =
export const LLAMA3_2_3B_SPINQUANT =
'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-3B/spinquant/llama3_2_3B_spinquant.pte';
export const LLAMA3_2_1B_URL =
export const LLAMA3_2_1B =
'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-1B/original/llama3_2_bf16.pte';
export const LLAMA3_2_1B_QLORA_URL =
export const LLAMA3_2_1B_QLORA =
'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-1B/QLoRA/llama3_2_qat_lora.pte';
export const LLAMA3_2_1B_SPINQUANT_URL =
export const LLAMA3_2_1B_SPINQUANT =
'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-1B/spinquant/llama3_2_spinquant.pte';
export const LLAMA3_2_1B_TOKENIZER =
'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.1.0/llama-3.2-1B/original/tokenizer.bin';
Expand All @@ -24,26 +24,32 @@ export const EFFICIENTNET_V2_S =
? 'https://huggingface.co/software-mansion/react-native-executorch-efficientnet-v2-s/resolve/v0.2.0/coreml/efficientnet_v2_s_coreml_all.pte'
: 'https://huggingface.co/software-mansion/react-native-executorch-efficientnet-v2-s/resolve/v0.2.0/xnnpack/efficientnet_v2_s_xnnpack.pte';

// Object detection
export const SSDLITE_320_MOBILENET_V3_LARGE =
'https://huggingface.co/software-mansion/react-native-executorch-ssdlite320-mobilenet-v3-large/resolve/v0.2.0/ssdlite320-mobilenetv3-large.pte';

// Style transfer
export const STYLE_TRANSFER_CANDY =
Platform.OS === 'ios'
? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-candy/resolve/v0.2.0/coreml/style_transfer_candy_coreml.pte'
: 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-candy/resolve/v0.2.0/xnnpack/style_transfer_candy_xnnpack.pte';

export const STYLE_TRANSFER_MOSAIC =
Platform.OS === 'ios'
? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-mosaic/resolve/main/coreml/style_transfer_mosaic_coreml.pte'
: 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-mosaic/resolve/main/xnnpack/style_transfer_mosaic_xnnpack.pte';

? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-mosaic/resolve/v0.2.0/coreml/style_transfer_mosaic_coreml.pte'
: 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-mosaic/resolve/v0.2.0/xnnpack/style_transfer_mosaic_xnnpack.pte';
export const STYLE_TRANSFER_RAIN_PRINCESS =
Platform.OS === 'ios'
? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-rain-princess/resolve/main/coreml/style_transfer_rain_princess_coreml.pte'
: 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-rain-princess/resolve/main/xnnpack/style_transfer_rain_princess_xnnpack.pte';
? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-rain-princess/resolve/v0.2.0/coreml/style_transfer_rain_princess_coreml.pte'
: 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-rain-princess/resolve/v0.2.0/xnnpack/style_transfer_rain_princess_xnnpack.pte';
export const STYLE_TRANSFER_UDNIE =
Platform.OS === 'ios'
? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/main/coreml/style_transfer_udnie_coreml.pte'
: 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/main/xnnpack/style_transfer_udnie_xnnpack.pte';
? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/v0.2.0/coreml/style_transfer_udnie_coreml.pte'
: 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/v0.2.0/xnnpack/style_transfer_udnie_xnnpack.pte';

// Object detection
export const SSDLITE_320_MOBILENET_V3_LARGE_URL =
'https://huggingface.co/software-mansion/react-native-executorch-ssdlite320-mobilenet-v3-large/resolve/main/ssdlite320-mobilenetv3-large.pte';
// Backward compatibility
export const LLAMA3_2_3B_URL = LLAMA3_2_3B;
export const LLAMA3_2_3B_QLORA_URL = LLAMA3_2_3B_QLORA;
export const LLAMA3_2_3B_SPINQUANT_URL = LLAMA3_2_3B_SPINQUANT;
export const LLAMA3_2_1B_URL = LLAMA3_2_1B;
export const LLAMA3_2_1B_QLORA_URL = LLAMA3_2_1B_QLORA;
export const LLAMA3_2_1B_SPINQUANT_URL = LLAMA3_2_1B_SPINQUANT;
12 changes: 6 additions & 6 deletions src/models/Classification.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { Classification } from '../native/RnExecutorchModules';
import { ETError, getError } from '../Error';

interface Props {
modulePath: string | number;
modelSource: string | number;
}

interface ClassificationModule {
Expand All @@ -15,18 +15,18 @@ interface ClassificationModule {
}

export const useClassification = ({
modulePath,
modelSource,
}: Props): ClassificationModule => {
const [error, setError] = useState<null | string>(null);
const [isModelReady, setIsModelReady] = useState(false);
const [isModelGenerating, setIsModelGenerating] = useState(false);

useEffect(() => {
const loadModel = async () => {
let path = modulePath;
let path = modelSource;

if (typeof modulePath === 'number') {
path = Image.resolveAssetSource(modulePath).uri;
if (typeof modelSource === 'number') {
path = Image.resolveAssetSource(modelSource).uri;
}

try {
Expand All @@ -40,7 +40,7 @@ export const useClassification = ({
};

loadModel();
}, [modulePath]);
}, [modelSource]);

const forward = async (input: string) => {
if (!isModelReady) {
Expand Down

0 comments on commit d79ae01

Please sign in to comment.