From c702876fce5aafdbbdd15c04b1bffd808f9c369f Mon Sep 17 00:00:00 2001
From: Christian M
Date: Sun, 7 Jan 2024 12:36:11 +0100
Subject: [PATCH] :sparkles: adds phi models

---
 README.md                  | 13 +++++++
 src/llm/loader.rs          |  1 +
 src/llm/models/mod.rs      | 74 ++++++++++++++------------------
 src/llm/text_generation.rs |  4 +--
 4 files changed, 43 insertions(+), 49 deletions(-)

diff --git a/README.md b/README.md
index 850f6a2..12c7b66 100644
--- a/README.md
+++ b/README.md
@@ -16,10 +16,18 @@ cargo build --release
 
 ### Running
 
+Run the server
+
 ```bash
 cargo run --release
 ```
 
+Run one of the models
+
+```bash
+cargo run --release -- --model phi-v2 --prompt 'write me fibonacci in rust'
+```
+
 ### Docker
 
 ```bash
@@ -68,6 +76,7 @@ python test.py
 - [x] Zephyr
 - [x] OpenChat
 - [x] Starling
+- [x] [Phi](https://huggingface.co/microsoft/phi-2) (Phi-1, Phi-1.5, Phi-2)
 - [ ] GPT-Neo
 - [ ] GPT-J
 - [ ] Llama
@@ -77,6 +86,10 @@ python test.py
 
 ["lmz/candle-mistral"](https://huggingface.co/lmz/candle-mistral)
 
+### Phi
+
+["microsoft/phi-2"](https://huggingface.co/microsoft/phi-2)
+
 ## Performance
 
 The following table shows the performance metrics of the model on different systems:
diff --git a/src/llm/loader.rs b/src/llm/loader.rs
index dc40512..e35f452 100644
--- a/src/llm/loader.rs
+++ b/src/llm/loader.rs
@@ -126,6 +126,7 @@ pub fn create_model(
             | Models::L70bChat
             | Models::OpenChat35
             | Models::Starling7bAlpha => 8,
+            Models::PhiHermes | Models::PhiV1 | Models::PhiV1_5 | Models::PhiV2 => 4,
         };
         ModelWeights::from_ggml(content, default_gqa)?
     }
diff --git a/src/llm/models/mod.rs b/src/llm/models/mod.rs
index 129d769..279e85f 100644
--- a/src/llm/models/mod.rs
+++ b/src/llm/models/mod.rs
@@ -43,6 +43,15 @@ pub enum Models {
     Mixtral,
     #[serde(rename = "mixtral-instruct")]
     MixtralInstruct,
+
+    #[serde(rename = "phi-hermes")]
+    PhiHermes,
+    #[serde(rename = "phi-v1")]
+    PhiV1,
+    #[serde(rename = "phi-v1.5")]
+    PhiV1_5,
+    #[serde(rename = "phi-v2")]
+    PhiV2,
 }
 
 #[derive(Deserialize)]
@@ -62,19 +71,6 @@ impl FromStr for Models {
 impl Models {
     pub fn is_mistral(&self) -> bool {
         match self {
-            Self::L7b
-            | Self::L13b
-            | Self::L70b
-            | Self::L7bChat
-            | Self::L13bChat
-            | Self::L70bChat
-            | Self::L7bCode
-            | Self::L13bCode
-            | Self::L34bCode
-            | Self::Leo7b
-            | Self::Leo13b => false,
-            // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
-            // same way. Starling is a fine tuned version of OpenChat.
             Self::OpenChat35
             | Self::Starling7bAlpha
             | Self::Zephyr7bAlpha
@@ -83,52 +79,28 @@ impl Models {
             | Self::MixtralInstruct
             | Self::Mistral7b
             | Self::Mistral7bInstruct => true,
+            _ => false,
         }
     }
 
     pub fn is_zephyr(&self) -> bool {
         match self {
-            Self::L7b
-            | Self::L13b
-            | Self::L70b
-            | Self::L7bChat
-            | Self::L13bChat
-            | Self::L70bChat
-            | Self::L7bCode
-            | Self::L13bCode
-            | Self::L34bCode
-            | Self::Leo7b
-            | Self::Leo13b
-            | Self::Mixtral
-            | Self::MixtralInstruct
-            | Self::Mistral7b
-            | Self::Mistral7bInstruct
-            | Self::OpenChat35
-            | Self::Starling7bAlpha => false,
             Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
+            _ => false,
         }
     }
 
     pub fn is_open_chat(&self) -> bool {
         match self {
-            Self::L7b
-            | Self::L13b
-            | Self::L70b
-            | Self::L7bChat
-            | Self::L13bChat
-            | Self::L70bChat
-            | Self::L7bCode
-            | Self::L13bCode
-            | Self::L34bCode
-            | Self::Leo7b
-            | Self::Leo13b
-            | Self::Mixtral
-            | Self::MixtralInstruct
-            | Self::Mistral7b
-            | Self::Mistral7bInstruct
-            | Self::Zephyr7bAlpha
-            | Self::Zephyr7bBeta => false,
             Self::OpenChat35 | Self::Starling7bAlpha => true,
+            _ => false,
+        }
+    }
+
+    pub fn is_phi(&self) -> bool {
+        match self {
+            Self::PhiHermes | Self::PhiV1 | Self::PhiV1_5 | Self::PhiV2 => true,
+            _ => false,
         }
     }
 
@@ -153,6 +125,10 @@
             | Models::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
             Models::OpenChat35 => "openchat/openchat_3.5",
             Models::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
+            Models::PhiV1 => "microsoft/phi-1",
+            Models::PhiV1_5 => "microsoft/phi-1.5",
+            Models::PhiV2 => "microsoft/phi-2",
+            Models::PhiHermes => "lmz/candle-quantized-phi",
         }
     }
 
@@ -210,6 +186,10 @@
                 "TheBloke/Starling-LM-7B-alpha-GGUF",
                 "starling-lm-7b-alpha.Q4_K_M.gguf",
             ),
+            Models::PhiV1 => ("lmz/candle-quantized-phi", "model-v1-q4k.gguf"),
+            Models::PhiV1_5 => ("lmz/candle-quantized-phi", "model-q4k.gguf"),
+            Models::PhiV2 => ("lmz/candle-quantized-phi", "model-v2-q4k.gguf"),
+            Models::PhiHermes => ("lmz/candle-quantized-phi", "model-phi-hermes-1_3B-q4k.gguf"),
         }
     }
 }
diff --git a/src/llm/text_generation.rs b/src/llm/text_generation.rs
index 0b0c295..0e06e68 100644
--- a/src/llm/text_generation.rs
+++ b/src/llm/text_generation.rs
@@ -241,8 +241,8 @@ pub fn create_text_generation(
     model: Models,
     cache_dir: &Option<String>,
 ) -> Result<TextGeneration> {
-    let tokenizer = create_tokenizer(model)?;
-    let model = create_model(model, cache_dir)?;
+    let tokenizer = create_tokenizer(model).expect("Failed to create tokenizer");
+    let model = create_model(model, cache_dir).expect("Failed to create model");
 
     let device = Device::Cpu;
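
Note (editor's illustration, not part of the patch): the `src/llm/text_generation.rs` hunk swaps `?` for `.expect(...)` inside a function that still returns a `Result`, which trades error propagation to the caller for a panic at the call site. A minimal standalone sketch of that difference, where `fallible` is a hypothetical stand-in for `create_tokenizer`/`create_model`:

```rust
// Illustration only; `fallible` stands in for the repo's fallible constructors.
fn fallible(ok: bool) -> Result<u32, String> {
    if ok { Ok(1) } else { Err("no tokenizer".into()) }
}

// With `?`, the Err is handed back to the caller and stays recoverable.
fn with_question_mark() -> Result<u32, String> {
    let v = fallible(false)?; // early-returns Err("no tokenizer")
    Ok(v + 1)
}

fn main() {
    assert!(with_question_mark().is_err());
    // With `.expect`, the same failure aborts the thread with a panic instead:
    // fallible(false).expect("Failed to create tokenizer");
}
```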
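The `#[serde(rename = "...")]` attributes in the `models/mod.rs` hunk are what let the kebab-case names, e.g. `phi-v2` in the README's `--model phi-v2` example, map to the new variants. A self-contained sketch using a trimmed, hypothetical mirror of the enum (assuming `serde` and `serde_json` as dependencies):

```rust
use serde::Deserialize;

// Trimmed mirror of the real Models enum in src/llm/models/mod.rs,
// reduced to two variants purely for illustration.
#[derive(Debug, Deserialize, PartialEq)]
enum Models {
    #[serde(rename = "phi-v1.5")]
    PhiV1_5,
    #[serde(rename = "phi-v2")]
    PhiV2,
}

fn main() {
    // Unit variants deserialize from their renamed string form.
    let m: Models = serde_json::from_str("\"phi-v2\"").unwrap();
    assert_eq!(m, Models::PhiV2);
}
```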
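Each arm added in the last `models/mod.rs` hunk pairs a Hugging Face repo id with a quantized GGUF filename. A minimal sketch of how such a pair can resolve to a local weights file, assuming the `hf-hub` crate's sync API (commonly used by candle-based loaders; this repo's `create_model` may wire the download differently):

```rust
use hf_hub::api::sync::Api;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // (repo id, filename) copied from the Models::PhiV2 arm in the patch.
    let (repo_id, filename) = ("lmz/candle-quantized-phi", "model-v2-q4k.gguf");

    // Api::new() uses the default Hugging Face cache directory; get()
    // downloads on first use and returns the cached path afterwards.
    let api = Api::new()?;
    let weights = api.model(repo_id.to_string()).get(filename)?;
    println!("quantized weights at {}", weights.display());
    Ok(())
}
```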