Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: wrong handle of assets files in release apk #41

Merged
merged 3 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(react
}

private fun downloadModel(
url: URL, resourceType: ResourceType, callback: (path: String?, error: Exception?) -> Unit
url: String, resourceType: ResourceType, callback: (path: String?, error: Exception?) -> Unit
) {
Fetcher.downloadResource(reactApplicationContext,
client,
url,
resourceType,
false,
{ path, error -> callback(path, error) },
object : ProgressResponseBody.ProgressListener {
override fun onProgress(bytesRead: Long, contentLength: Long, done: Boolean) {
Expand All @@ -38,7 +39,7 @@ class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(react
override fun loadModule(modelPath: String, promise: Promise) {
try {
downloadModel(
URL(modelPath), ResourceType.MODEL
modelPath, ResourceType.MODEL
) { path, error ->
if (error != null) {
promise.reject(error.message!!, "-1")
Expand Down
chmjkb marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package com.swmansion.rnexecutorch

import android.os.Build
import android.util.Log
import androidx.annotation.RequiresApi
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.utils.Fetcher
Expand Down Expand Up @@ -47,12 +45,13 @@ class RnExecutorchModule(reactContext: ReactApplicationContext) :
}

private fun downloadResource(
url: URL,
url: String,
resourceType: ResourceType,
callback: (path: String?, error: Exception?) -> Unit
isLargeFile: Boolean = false,
callback: (path: String?, error: Exception?) -> Unit,
) {
Fetcher.downloadResource(
reactApplicationContext, client, url, resourceType,
reactApplicationContext, client, url, resourceType, isLargeFile,
{ path, error -> callback(path, error) },
object : ProgressResponseBody.ProgressListener {
override fun onProgress(bytesRead: Long, contentLength: Long, done: Boolean) {
Expand All @@ -71,7 +70,6 @@ class RnExecutorchModule(reactContext: ReactApplicationContext) :
promise.resolve("Model loaded successfully")
}

@RequiresApi(Build.VERSION_CODES.TIRAMISU)
override fun loadLLM(
modelSource: String,
tokenizerSource: String,
Expand All @@ -85,14 +83,12 @@ class RnExecutorchModule(reactContext: ReactApplicationContext) :
}

try {
val modelURL = URL(modelSource)
val tokenizerURL = URL(tokenizerSource)
this.conversationManager = ConversationManager(contextWindowLength.toInt(), systemPrompt)

isFetching = true

downloadResource(
tokenizerURL,
tokenizerSource,
ResourceType.TOKENIZER
) tokenizerDownload@{ tokenizerPath, error ->
if (error != null) {
Expand All @@ -101,7 +97,7 @@ class RnExecutorchModule(reactContext: ReactApplicationContext) :
return@tokenizerDownload
}

downloadResource(modelURL, ResourceType.MODEL) modelDownload@{ modelPath, modelError ->
downloadResource(modelSource, ResourceType.MODEL, isLargeFile = true) modelDownload@{ modelPath, modelError ->
if (modelError != null) {
promise.reject(
"Download Error",
Expand All @@ -120,7 +116,6 @@ class RnExecutorchModule(reactContext: ReactApplicationContext) :
}
}

@RequiresApi(Build.VERSION_CODES.N)
NorbertKlockiewicz marked this conversation as resolved.
Show resolved Hide resolved
override fun runInference(
input: String,
promise: Promise
Expand Down
182 changes: 127 additions & 55 deletions android/src/main/java/com/swmansion/rnexecutorch/utils/Fetcher.kt
NorbertKlockiewicz marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ class Fetcher {
return file
}

private fun hasValidExtension(fileName: String, resourceType: ResourceType): Boolean {
private fun getValidExtension(resourceType: ResourceType): String {
return when (resourceType) {
ResourceType.TOKENIZER -> {
fileName.endsWith(".bin")
"bin"
}

ResourceType.MODEL -> {
fileName.endsWith(".pte")
"pte"
}
}
}
Expand All @@ -47,17 +47,9 @@ class Fetcher {
if (url.path == "/assets/") {
val pathSegments = url.toString().split('/')
return pathSegments[pathSegments.size - 1].split("?")[0]
} else if (url.protocol == "file") {
val localPath = url.toString().split("://")[1]
val file = File(localPath)
if (file.exists()) {
return localPath
}

throw Exception("file_not_found")
} else {
return url.path.substringAfterLast('/')
}

return url.path.substringAfterLast('/')
}

private fun fetchModel(
Expand Down Expand Up @@ -132,48 +124,74 @@ class Fetcher {
return response
}

fun downloadResource(
private fun getIdOfResource(
context: Context,
client: OkHttpClient,
resourceName: String,
defType: String = "raw"
): Int {
return context.resources.getIdentifier(resourceName, defType, context.packageName)
}

private fun getResourceFromAssets(
context: Context,
url: String,
resourceType: ResourceType,
onComplete: (String?, Exception?) -> Unit
) {
if (!url.contains("://")) {
//The provided file is from react-native assets folder in release mode
val resId = getIdOfResource(context, url)
val resName = context.resources.getResourceEntryName(resId)
val fileExtension = getValidExtension(resourceType)
context.resources.openRawResource(resId).use { inputStream ->
val file = File(
context.filesDir,
"$resName.$fileExtension"
)
file.outputStream().use { outputStream ->
inputStream.copyTo(outputStream)
}
onComplete(file.absolutePath, null)
return
}
}
}

private fun getLocalFile(
url: URL,
resourceType: ResourceType,
onComplete: (String?, Exception?) -> Unit,
listener: ProgressResponseBody.ProgressListener? = null,
onComplete: (String?, Exception?) -> Unit
) {
/*
Fetching model and tokenizer file
1. Extract file name from provided URL
2. If file name contains / it means that the file is local and we should return the path
3. Check if the file has a valid extension
a. For tokenizer, the extension should be .bin
b. For model, the extension should be .pte
4. Check if models directory exists, if not create it
5. Check if the file already exists in the models directory, if yes return the path
6. If the file does not exist, and is a tokenizer, fetch the file
7. If the file is a model, fetch the file with ProgressResponseBody
*/
val fileName: String
// The provided file is a local file, get rid of the file:// prefix and return path
if (url.protocol == "file") {
val localPath = url.path
if (getValidExtension(resourceType) != localPath.takeLast(3)) {
throw Exception("invalid_extension")
}

try {
fileName = extractFileName(url)
} catch (e: Exception) {
onComplete(null, e)
return
}
val file = File(localPath)
if (file.exists()) {
onComplete(localPath, null)
return
}

if (fileName.contains("/")) {
onComplete(fileName, null)
return
throw Exception("file_not_found")
}
}

if (!hasValidExtension(fileName, resourceType)) {
onComplete(null, Exception("invalid_resource_extension"))
return
}
private fun getRemoteFile(
context: Context,
client: OkHttpClient,
url: URL,
resourceType: ResourceType,
isLargeFile: Boolean,
onComplete: (String?, Exception?) -> Unit,
listener: ProgressResponseBody.ProgressListener?
) {
val fileName = extractFileName(url)

val tempFile = File(context.filesDir, fileName)
if (tempFile.exists()) {
tempFile.delete()
if (getValidExtension(resourceType) != fileName.takeLast(3)) {
throw Exception("invalid_extension")
}

val modelsDirectory = File(context.filesDir, "models").apply {
Expand All @@ -188,29 +206,83 @@ class Fetcher {
return
}

if (resourceType == ResourceType.TOKENIZER) {
// If the url is a Software Mansion HuggingFace repo, we want to send a HEAD
// request to the config.json file, this increments HF download counter
// https://huggingface.co/docs/hub/models-download-stats
if (isUrlPointingToHfRepo(url)) {
val configUrl = resolveConfigUrlFromModelUrl(url)
sendRequestToUrl(configUrl, "HEAD", null, client)
}

if (!isLargeFile) {
val request = Request.Builder().url(url).build()
val response = client.newCall(request).execute()

if (!response.isSuccessful) {
onComplete(null, Exception("download_error"))
return
throw Exception("download_error")
}

validFile = saveResponseToFile(response, modelsDirectory, fileName)
onComplete(validFile.absolutePath, null)
return
}

// If the url is a Software Mansion HuggingFace repo, we want to send a HEAD
// request to the config.json file, this increments HF download counter
// https://huggingface.co/docs/hub/models-download-stats
if (isUrlPointingToHfRepo(url)) {
val configUrl = resolveConfigUrlFromModelUrl(url)
sendRequestToUrl(configUrl, "HEAD", null, client)
val tempFile = File(context.filesDir, fileName)
if (tempFile.exists()) {
tempFile.delete()
}

fetchModel(tempFile, validFile, client, url, onComplete, listener)
}

fun downloadResource(
context: Context,
client: OkHttpClient,
url: String,
resourceType: ResourceType,
isLargeFile: Boolean,
onComplete: (String?, Exception?) -> Unit,
listener: ProgressResponseBody.ProgressListener? = null,
) {
/*
Fetching model and tokenizer file
1. Check if the provided file is a bundled local file
a. Check if it exists
b. Check if it has valid extension
c. Copy the file and return the path
2. Check if the provided file is a path to a local file
a. Check if it exists
b. Check if it has valid extension
c. Return the path
3. The provided file is a remote file
a. Check if it has valid extension
b. Check if it's a large file
i. Create temporary file to store it at download time
ii. Move it to the models directory and return the path
c. If it's not a large file download it and return the path
*/

try {
getResourceFromAssets(context, url, resourceType, onComplete)

val resUrl = URL(url)
/*
The provided file is either a remote file or a local file
- local file: file:///path/to/file
- remote file: https://path/to/file || http://10.0.2.2:8080/path/to/file
*/
getLocalFile(resUrl, resourceType, onComplete)

/*
The provided file is a remote file, if it's a large file
create temporary file to store it at download time and later
move it to the models directory
*/
getRemoteFile(context, client, resUrl, resourceType, isLargeFile, onComplete, listener)
} catch (e: Exception) {
onComplete(null, e)
return
}
}
}
}
4 changes: 4 additions & 0 deletions docs/docs/fundamentals/getting-started.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ This allows us to use binaries, such as exported models or tokenizers for LLMs.
When using Expo, please note that you need to use a custom development build of your app, not the standard Expo Go app. This is because we rely on native modules, which Expo Go doesn’t support.
:::

:::info[Info]
Because we are using ExecuTorch under the hood, you won't be able to build ios app for release with simulator selected as the target device. Make sure to test release builds on real devices.
:::

Running the app with the library:
```bash
yarn run expo:<ios | android> -d
Expand Down
Loading