From edcc9580819fed1127a003b436dde3fb4d1fa4c1 Mon Sep 17 00:00:00 2001
From: Norbert Klockiewicz
Date: Thu, 19 Dec 2024 13:42:31 +0100
Subject: [PATCH 1/3] fix: issue with isModelGenerating when switching between
 multiple models

---
 .../java/com/swmansion/rnexecutorch/LLM.kt |  9 +++----
 ios/RnExecutorch/LLM.mm                    | 25 ++++++++-----------
 ios/RnExecutorch/utils/LargeFileFetcher.mm |  4 +--
 src/LLM.ts                                 | 11 +++-----
 src/native/NativeLLM.ts                    |  1 -
 5 files changed, 19 insertions(+), 31 deletions(-)

diff --git a/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt b/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt
index e12f027..1ed293f 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt
@@ -64,6 +64,7 @@ class LLM(reactContext: ReactApplicationContext) :
   private fun initializeLlamaModule(modelPath: String, tokenizerPath: String, promise: Promise) {
     llamaModule = LlamaModule(1, modelPath, tokenizerPath, 0.7f)
     isFetching = false
+    this.tempLlamaResponse.clear()
     promise.resolve("Model loaded successfully")
   }
 
@@ -74,8 +75,8 @@ class LLM(reactContext: ReactApplicationContext) :
     contextWindowLength: Double,
     promise: Promise
   ) {
-    if (llamaModule != null || isFetching) {
-      promise.reject("Model already loaded", "Model is already loaded or fetching")
+    if (isFetching) {
+      promise.reject("Model already loaded", "Model is already fetching")
       return
     }
 
@@ -148,10 +149,6 @@ class LLM(reactContext: ReactApplicationContext) :
     llamaModule!!.stop()
   }
 
-  override fun deleteModule() {
-    llamaModule = null
-  }
-
   companion object {
     const val NAME = "LLM"
   }
diff --git a/ios/RnExecutorch/LLM.mm b/ios/RnExecutorch/LLM.mm
index 2f270c6..3837b29 100644
--- a/ios/RnExecutorch/LLM.mm
+++ b/ios/RnExecutorch/LLM.mm
@@ -28,7 +28,7 @@ - (instancetype)init {
     isFetching = NO;
     tempLlamaResponse = [[NSMutableString alloc] init];
   }
-  
+
   return self;
 }
 
@@ -38,7 +38,7 @@ - (void)onResult:(NSString *)token prompt:(NSString *)prompt {
   if ([token isEqualToString:prompt]) {
     return;
   }
-  
+
   dispatch_async(dispatch_get_main_queue(), ^{
     [self emitOnToken:token];
     [self->tempLlamaResponse appendString:token];
@@ -54,8 +54,8 @@ - (void)updateDownloadProgress:(NSNumber *)progress {
 - (void)loadLLM:(NSString *)modelSource tokenizerSource:(NSString *)tokenizerSource systemPrompt:(NSString *)systemPrompt contextWindowLength:(double)contextWindowLength resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject {
   NSURL *modelURL = [NSURL URLWithString:modelSource];
   NSURL *tokenizerURL = [NSURL URLWithString:tokenizerSource];
-  
-  if(self->runner || isFetching){
+
+  if(isFetching){
     reject(@"model_already_loaded", @"Model and tokenizer already loaded", nil);
     return;
   }
@@ -78,10 +78,11 @@ - (void)loadLLM:(NSString *)modelSource tokenizerSource:(NSString *)tokenizerSou
 
   modelFetcher.onFinish = ^(NSString *modelFilePath) {
     self->runner = [[LLaMARunner alloc] initWithModelPath:modelFilePath tokenizerPath:tokenizerFilePath];
-    NSUInteger contextWindowLengthUInt = (NSUInteger)round(contextWindowLength); 
+    NSUInteger contextWindowLengthUInt = (NSUInteger)round(contextWindowLength);
     self->conversationManager = [[ConversationManager alloc] initWithNumMessagesContextWindow: contextWindowLengthUInt systemPrompt: systemPrompt];
     self->isFetching = NO;
+    self->tempLlamaResponse = [NSMutableString string];
     resolve(@"Model and tokenizer loaded successfully");
     return;
   };
@@ -94,23 +95,23 @@ - (void) runInference:(NSString *)input resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject {
   [conversationManager addResponse:input senderRole:ChatRole::USER];
   NSString *prompt = [conversationManager getConversation];
-  
+
   dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
     NSError *error = nil;
     [self->runner generate:prompt withTokenCallback:^(NSString *token) {
-      [self onResult:token prompt:prompt]; 
+      [self onResult:token prompt:prompt];
     } error:&error];
-    
+
     // make sure to add eot token once generation is done
     if (![self->tempLlamaResponse hasSuffix:END_OF_TEXT_TOKEN_NS]) {
       [self onResult:END_OF_TEXT_TOKEN_NS prompt:prompt];
     }
-    
+
     if (self->tempLlamaResponse) {
       [self->conversationManager addResponse:self->tempLlamaResponse senderRole:ChatRole::ASSISTANT];
       self->tempLlamaResponse = [NSMutableString string];
     }
-    
+
     if (error) {
       reject(@"error_in_generation", error.localizedDescription, nil);
       return;
@@ -125,10 +126,6 @@ -(void)interrupt {
   [self->runner stop];
 }
 
--(void)deleteModule {
-  self->runner = nil;
-}
-
 - (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:(const facebook::react::ObjCTurboModule::InitParams &)params
 {
   return std::make_shared<facebook::react::NativeLLMSpecJSI>(params);
diff --git a/ios/RnExecutorch/utils/LargeFileFetcher.mm b/ios/RnExecutorch/utils/LargeFileFetcher.mm
index 6ae58db..48cc39b 100644
--- a/ios/RnExecutorch/utils/LargeFileFetcher.mm
+++ b/ios/RnExecutorch/utils/LargeFileFetcher.mm
@@ -12,7 +12,7 @@ @implementation LargeFileFetcher {
 - (instancetype)init {
   self = [super init];
   if (self) {
-    NSURLSessionConfiguration *configuration = [NSURLSessionConfiguration backgroundSessionConfigurationWithIdentifier:@"com.swmansion.rnexecutorch"];
+    NSURLSessionConfiguration *configuration = [NSURLSessionConfiguration backgroundSessionConfigurationWithIdentifier:[NSString stringWithFormat:@"com.swmansion.rnexecutorch.%@", [[NSUUID UUID] UUIDString]]];
     _session = [NSURLSession sessionWithConfiguration:configuration delegate:self delegateQueue:nil];
   }
   return self;
@@ -111,7 +111,7 @@ - (void)startDownloadingFileFromURL:(NSURL *)url {
 - (void)URLSession:(NSURLSession *)session downloadTask:(NSURLSessionDownloadTask *)downloadTask didFinishDownloadingToURL:(NSURL *)location {
   NSFileManager *fileManager = [NSFileManager defaultManager];
-  
+
   [fileManager removeItemAtPath:_destination error:nil];
 
   NSError *error;
diff --git a/src/LLM.ts b/src/LLM.ts
index 3cd67ff..7065be3 100644
--- a/src/LLM.ts
+++ b/src/LLM.ts
@@ -30,15 +30,8 @@ export const useLLM = ({
   const [downloadProgress, setDownloadProgress] = useState(0);
   const downloadProgressListener = useRef(null);
   const tokenGeneratedListener = useRef(null);
-  const initialized = useRef(false);
 
   useEffect(() => {
-    if (initialized.current) {
-      return;
-    }
-
-    initialized.current = true;
-
     const loadModel = async () => {
       try {
         let modelUrl = modelSource;
@@ -57,6 +50,7 @@
           }
         }
       );
+      setIsReady(false);
 
       await LLM.loadLLM(
         modelUrl as string,
@@ -83,6 +77,8 @@
         const message = (err as Error).message;
         setIsReady(false);
         setError(message);
+      } finally {
+        setDownloadProgress(0);
       }
     };
 
@@ -93,7 +89,6 @@
       downloadProgressListener.current = null;
       tokenGeneratedListener.current?.remove();
      tokenGeneratedListener.current = null;
-      LLM.deleteModule();
     };
   }, [contextWindowLength, modelSource, systemPrompt, tokenizerSource]);
diff --git a/src/native/NativeLLM.ts b/src/native/NativeLLM.ts
index 23ee518..9cb2c95 100644
--- a/src/native/NativeLLM.ts
+++ b/src/native/NativeLLM.ts
@@ -10,7 +10,6 @@ export interface Spec extends TurboModule {
     contextWindowLength: number
   ): Promise<string>;
   runInference(input: string): Promise<string>;
-  deleteModule(): void;
   interrupt(): void;
 
   readonly onToken: EventEmitter<string>;

From abb4101e9d2dee367499a46b345c497f6cd390fb Mon Sep 17 00:00:00 2001
From: Norbert Klockiewicz
Date: Fri, 20 Dec 2024 15:59:33 +0100
Subject: [PATCH 2/3] chore: change error names, undelete deleteModule

---
 android/src/main/java/com/swmansion/rnexecutorch/LLM.kt | 6 +++++-
 ios/RnExecutorch/LLM.mm                                 | 7 +++++--
 src/LLM.ts                                              | 1 +
 src/native/NativeLLM.ts                                 | 1 +
 4 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt b/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt
index 1ed293f..393468b 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt
@@ -76,7 +76,7 @@ class LLM(reactContext: ReactApplicationContext) :
     promise: Promise
   ) {
     if (isFetching) {
-      promise.reject("Model already loaded", "Model is already fetching")
+      promise.reject("Model is fetching", "Model is fetching")
       return
     }
 
@@ -149,6 +149,10 @@ class LLM(reactContext: ReactApplicationContext) :
     llamaModule!!.stop()
   }
 
+  override fun deleteModule() {
+    llamaModule = null
+  }
+
   companion object {
     const val NAME = "LLM"
   }
diff --git a/ios/RnExecutorch/LLM.mm b/ios/RnExecutorch/LLM.mm
index 3837b29..35d40fa 100644
--- a/ios/RnExecutorch/LLM.mm
+++ b/ios/RnExecutorch/LLM.mm
@@ -56,7 +56,7 @@ - (void)loadLLM:(NSString *)modelSource tokenizerSource:(NSString *)tokenizerSou
   NSURL *tokenizerURL = [NSURL URLWithString:tokenizerSource];
 
   if(isFetching){
-    reject(@"model_already_loaded", @"Model and tokenizer already loaded", nil);
+    reject(@"model_is_loaded", @"Model is fetching", nil);
     return;
   }
 
@@ -121,11 +121,14 @@ - (void) runInference:(NSString *)input resolve:(RCTPromiseResolveBlock)resolve
   });
 }
 
-
 -(void)interrupt {
   [self->runner stop];
 }
 
+-(void)deleteModule {
+  self->runner = nil;
+}
+
 - (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:(const facebook::react::ObjCTurboModule::InitParams &)params
 {
   return std::make_shared<facebook::react::NativeLLMSpecJSI>(params);
diff --git a/src/LLM.ts b/src/LLM.ts
index 7065be3..4219fdc 100644
--- a/src/LLM.ts
+++ b/src/LLM.ts
@@ -89,6 +89,7 @@
       downloadProgressListener.current = null;
       tokenGeneratedListener.current?.remove();
       tokenGeneratedListener.current = null;
+      LLM.deleteModule();
     };
   }, [contextWindowLength, modelSource, systemPrompt, tokenizerSource]);
diff --git a/src/native/NativeLLM.ts b/src/native/NativeLLM.ts
index 9cb2c95..35d2a42 100644
--- a/src/native/NativeLLM.ts
+++ b/src/native/NativeLLM.ts
@@ -11,6 +11,7 @@ export interface Spec extends TurboModule {
   ): Promise<string>;
   runInference(input: string): Promise<string>;
   interrupt(): void;
+  deleteModule(): void;
 
   readonly onToken: EventEmitter<string>;
   readonly onDownloadProgress: EventEmitter<number>;

From d2a037e7effbd67a27aa042476a67304077301b0 Mon Sep 17 00:00:00 2001
From: Norbert Klockiewicz
Date: Fri, 20 Dec 2024 16:03:14 +0100
Subject: [PATCH 3/3] fix: typo

---
 ios/RnExecutorch/LLM.mm | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ios/RnExecutorch/LLM.mm b/ios/RnExecutorch/LLM.mm
index 35d40fa..8b6957f 100644
--- a/ios/RnExecutorch/LLM.mm
+++ b/ios/RnExecutorch/LLM.mm
@@ -56,7 +56,7 @@ - (void)loadLLM:(NSString *)modelSource tokenizerSource:(NSString *)tokenizerSou
   NSURL *tokenizerURL = [NSURL URLWithString:tokenizerSource];
 
   if(isFetching){
-    reject(@"model_is_loaded", @"Model is fetching", nil);
reject(@"model_is_fetching", @"Model is fetching", nil); return; }