Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring and small enhancements #80

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
1 change: 1 addition & 0 deletions .github/setup.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
git config core.hooksPath .github/hooks
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ obj/
obj.meta
bin/
bin.meta

*.api
*.api.meta
13 changes: 11 additions & 2 deletions Editor/LLMEditor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,19 @@ public class LLMEditor : LLMClientEditor
public void AddModelLoaders(SerializedObject llmScriptSO, LLM llmScript)
{
EditorGUILayout.BeginHorizontal();
if (GUILayout.Button("Download model", GUILayout.Width(buttonWidth)))

string[] options = new string[llmScript.modelOptions.Length];
for (int i = 0; i < llmScript.modelOptions.Length; i++)
{
llmScript.DownloadModel();
options[i] = llmScript.modelOptions[i].Item1;
}

int newIndex = EditorGUILayout.Popup("Model", llmScript.SelectedOption, options);
if (newIndex != llmScript.SelectedOption)
{
llmScript.DownloadModel(newIndex);
}

if (GUILayout.Button("Load model", GUILayout.Width(buttonWidth)))
{
EditorApplication.delayCall += () =>
Expand Down
29 changes: 18 additions & 11 deletions Runtime/LLM.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,18 @@ public class LLM : LLMClient
[ModelAdvanced] public int contextSize = 512;
[ModelAdvanced] public int batchSize = 512;

[HideInInspector] public string modelUrl = "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q4_K_M.gguf?download=true";
[HideInInspector] public readonly (string, string)[] modelOptions = new (string, string)[]{
("Download model", null),
("Phi 2 (small, best)", "https://huggingface.co/TheBloke/phi-2-GGUF/resolve/main/phi-2.Q4_K_M.gguf?download=true"),
("Mistral 7B Instruct v0.2 (medium, best)", "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q4_K_M.gguf?download=true")
};
public int SelectedOption = 0;
private static readonly string serverZipUrl = "https://github.com/Mozilla-Ocho/llamafile/releases/download/0.6/llamafile-0.6.zip";
private static readonly string server = Path.Combine(GetAssetPath(Path.GetFileNameWithoutExtension(serverZipUrl)), "bin/llamafile");
private static readonly string server = Path.Combine(LLMUnitySetup.GetAssetPath(Path.GetFileNameWithoutExtension(serverZipUrl)), "bin/llamafile");
private static readonly string apeARMUrl = "https://cosmo.zip/pub/cosmos/bin/ape-arm64.elf";
private static readonly string apeARM = GetAssetPath("ape-arm64.elf");
private static readonly string apeARM = LLMUnitySetup.GetAssetPath("ape-arm64.elf");
private static readonly string apeX86_64Url = "https://cosmo.zip/pub/cosmos/bin/ape-x86_64.elf";
private static readonly string apeX86_64 = GetAssetPath("ape-x86_64.elf");
private static readonly string apeX86_64 = LLMUnitySetup.GetAssetPath("ape-x86_64.elf");

[HideInInspector] public static float binariesProgress = 1;
[HideInInspector] public float modelProgress = 1;
Expand Down Expand Up @@ -65,7 +70,7 @@ private static async Task SetupBinaries()
string serverZip = Path.Combine(Application.temporaryCachePath, "llamafile.zip");
await LLMUnitySetup.DownloadFile(serverZipUrl, serverZip, true, false, null, SetBinariesProgress);
binariesDone += 1;
LLMUnitySetup.ExtractZip(serverZip, GetAssetPath());
LLMUnitySetup.ExtractZip(serverZip, LLMUnitySetup.GetAssetPath());
File.Delete(serverZip);
binariesDone += 1;
}
Expand All @@ -77,12 +82,14 @@ public static void SetBinariesProgress(float progress)
binariesProgress = binariesDone / 4f + 1f / 4f * progress;
}

public void DownloadModel()
public void DownloadModel(int optionIndex)
{
// download default model and disable model editor properties until the model is set
modelProgress = 0;
SelectedOption = optionIndex;
string modelUrl = modelOptions[optionIndex].Item2;
string modelName = Path.GetFileName(modelUrl).Split("?")[0];
string modelPath = GetAssetPath(modelName);
string modelPath = LLMUnitySetup.GetAssetPath(modelName);
Task downloadTask = LLMUnitySetup.DownloadFile(modelUrl, modelPath, false, false, SetModel, SetModelProgress);
}

Expand All @@ -95,7 +102,7 @@ public async Task SetModel(string path)
{
// set the model and enable the model editor properties
modelCopyProgress = 0;
model = await LLMUnitySetup.AddAsset(path, GetAssetPath());
model = await LLMUnitySetup.AddAsset(path, LLMUnitySetup.GetAssetPath());
EditorUtility.SetDirty(this);
modelCopyProgress = 1;
}
Expand All @@ -104,7 +111,7 @@ public async Task SetLora(string path)
{
// set the lora and enable the model editor properties
modelCopyProgress = 0;
lora = await LLMUnitySetup.AddAsset(path, GetAssetPath());
lora = await LLMUnitySetup.AddAsset(path, LLMUnitySetup.GetAssetPath());
EditorUtility.SetDirty(this);
modelCopyProgress = 1;
}
Expand Down Expand Up @@ -232,13 +239,13 @@ private void StartLLMServer()

// Start the LLM server in a cross-platform way
if (model == "") throw new Exception("No model file provided!");
string modelPath = GetAssetPath(model);
string modelPath = LLMUnitySetup.GetAssetPath(model);
if (!File.Exists(modelPath)) throw new Exception($"File {modelPath} not found!");

string loraPath = "";
if (lora != "")
{
loraPath = GetAssetPath(lora);
loraPath = LLMUnitySetup.GetAssetPath(lora);
if (!File.Exists(loraPath)) throw new Exception($"File {loraPath} not found!");
}

Expand Down
11 changes: 2 additions & 9 deletions Runtime/LLMClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,6 @@ public async Task SetPrompt(string newPrompt, bool clearChat = true)
nKeep = -1;
await InitPrompt(clearChat);
}

protected static string GetAssetPath(string relPath = "")
{
// Path to store llm server binaries and models
return Path.Combine(Application.streamingAssetsPath, relPath).Replace('\\', '/');
}

private async Task InitNKeep()
{
if (setNKeepToPrompt && nKeep == -1)
Expand All @@ -163,7 +156,7 @@ private void InitGrammar()
{
if (grammar != null && grammar != "")
{
grammarString = File.ReadAllText(GetAssetPath(grammar));
grammarString = File.ReadAllText(LLMUnitySetup.GetAssetPath(grammar));
}
}

Expand All @@ -176,7 +169,7 @@ private void SetNKeep(List<int> tokens)
#if UNITY_EDITOR
public async void SetGrammar(string path)
{
grammar = await LLMUnitySetup.AddAsset(path, GetAssetPath());
grammar = await LLMUnitySetup.AddAsset(path, LLMUnitySetup.GetAssetPath());
}
#endif

Expand Down
23 changes: 18 additions & 5 deletions Runtime/LLMUnitySetup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ public static void makeExecutable(string path)
}
}

public static string GetAssetPath(string relPath = "")
{
// Path to store llm server binaries and models
return Path.Combine(Application.streamingAssetsPath, relPath).Replace('\\', '/');
}

#if UNITY_EDITOR
public class DownloadStatus
{
Expand All @@ -91,7 +97,9 @@ public void DownloadProgressChanged(object sender, DownloadProgressChangedEventA

public static async Task DownloadFile(
string fileUrl, string savePath, bool overwrite = false, bool executable = false,
TaskCallback<string> callback = null, Callback<float> progresscallback = null)
TaskCallback<string> callback = null, Callback<float> progresscallback = null,
bool async=true
)
{
// download a file to the specified path
if (File.Exists(savePath) && !overwrite)
Expand All @@ -106,18 +114,23 @@ public static async Task DownloadFile(
WebClient client = new WebClient();
DownloadStatus downloadStatus = new DownloadStatus(progresscallback);
client.DownloadProgressChanged += downloadStatus.DownloadProgressChanged;
await client.DownloadFileTaskAsync(fileUrl, tmpPath);
if (async)
{
await client.DownloadFileTaskAsync(fileUrl, tmpPath);
} else {
client.DownloadFile(fileUrl, tmpPath);
}
if (executable) makeExecutable(tmpPath);

AssetDatabase.StartAssetEditing();
Directory.CreateDirectory(Path.GetDirectoryName(savePath));
File.Move(tmpPath, savePath);
AssetDatabase.StopAssetEditing();
Debug.Log($"Download complete!");

progresscallback?.Invoke(1f);
callback?.Invoke(savePath);
}

progresscallback?.Invoke(1f);
callback?.Invoke(savePath);
}

public static async Task<string> AddAsset(string assetPath, string basePath)
Expand Down
8 changes: 0 additions & 8 deletions hooks.meta

This file was deleted.

7 changes: 0 additions & 7 deletions hooks/pre-commit.meta

This file was deleted.

1 change: 0 additions & 1 deletion setup.sh

This file was deleted.

7 changes: 0 additions & 7 deletions setup.sh.meta

This file was deleted.

Loading