Added Gemma (#267)
tgaddair authored Feb 21, 2024
1 parent b407adc commit 1ade8b8
Showing 10 changed files with 852 additions and 91 deletions.
22 changes: 16 additions & 6 deletions docs/models/adapters.md
@@ -36,13 +36,15 @@ Any combination of linear layers can be targeted in the adapters, which correspo
 - `o_proj`
 - `lm_head`
 
-### Qwen
+### Gemma
 
-- `c_attn`
-- `c_proj`
-- `w1`
-- `w2`
-- `lm_head`
+- `q_proj`
+- `k_proj`
+- `v_proj`
+- `o_proj`
+- `gate_proj`
+- `up_proj`
+- `down_proj`
 
 ### Phi
 
@@ -54,6 +56,14 @@ Any combination of linear layers can be targeted in the adapters, which correspo
 - `fc2`
 - `lm_head`
 
+### Qwen
+
+- `c_attn`
+- `c_proj`
+- `w1`
+- `w2`
+- `lm_head`
+
 ### GPT2
 
 - `c_attn`
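
Though not part of this commit, the Gemma target modules listed above map directly onto a standard LoRA training configuration. A minimal sketch using the `peft` library, with `google/gemma-2b` as an illustrative (assumed) base model id:

```python
# Hypothetical example: train a LoRA adapter against the Gemma layers
# documented above. Assumes `peft` and `transformers` are installed; the
# model id "google/gemma-2b" is illustrative, not taken from this commit.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    # Any combination of the linear layers listed in docs/models/adapters.md
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    task_type="CAUSAL_LM",
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
```

Per the note in the hunk header, any subset of the listed modules is valid as `target_modules`.
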
3 changes: 2 additions & 1 deletion docs/models/base_models.md
@@ -7,8 +7,9 @@
 - 🌬️[Mistral](https://huggingface.co/mistralai)
 - [Zephyr](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta)
 - 🔄 [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)
-- 🔮 [Qwen](https://huggingface.co/Qwen)
+- 💎 [Gemma](https://blog.google/technology/developers/gemma-open-models/)
 - 🏛️ [Phi](https://huggingface.co/microsoft/phi-2)
+- 🔮 [Qwen](https://huggingface.co/Qwen)
 - 🤖 [GPT2](https://huggingface.co/gpt2)
 - 🌸 [Bloom](https://huggingface.co/bigscience/bloom)

16 changes: 16 additions & 0 deletions server/lorax_server/models/__init__.py
@@ -50,6 +50,7 @@
 from lorax_server.models.flash_rw import FlashRWSharded
 from lorax_server.models.flash_neox import FlashNeoXSharded
 from lorax_server.models.flash_llama import FlashLlama
+from lorax_server.models.flash_gemma import FlashGemma
 from lorax_server.models.flash_gpt2 import FlashGPT2
 from lorax_server.models.flash_qwen import FlashQwen
 from lorax_server.models.flash_phi import FlashPhi
@@ -66,6 +67,7 @@
 __all__.append(FlashRWSharded)
 __all__.append(FlashSantacoderSharded)
 __all__.append(FlashLlama)
+__all__.append(FlashGemma)
 __all__.append(FlashGPT2)
 __all__.append(FlashQwen)
 __all__.append(FlashPhi)
@@ -361,6 +363,20 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         raise NotImplementedError("Phi model requires flash attention v2")
+
+    if model_type == "gemma":
+        if FLASH_ATTENTION:
+            return FlashGemma(
+                model_id,
+                adapter_id,
+                adapter_source,
+                revision,
+                quantize=quantize,
+                compile=compile,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        raise NotImplementedError("Gemma model requires flash attention v2")
 
     if model_type == "opt":
         return OPTSharded(
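
For readers skimming the diff, here is a simplified, self-contained sketch of the dispatch pattern the new branch follows: each `model_type` is matched in turn, and flash-attention-only architectures raise when the capability is missing. Only the names `get_model`, `FlashGemma`, and `FLASH_ATTENTION` come from this commit; everything else below is an illustrative stand-in, not the real `lorax_server` implementation.

```python
# Simplified sketch of the dispatch added in get_model above. FlashGemma and
# FLASH_ATTENTION are stand-ins; the real classes live in lorax_server and
# take many more constructor arguments than shown here.
FLASH_ATTENTION = True  # in lorax_server this is detected from the environment


class FlashGemma:
    """Illustrative stand-in for lorax_server.models.flash_gemma.FlashGemma."""

    def __init__(self, model_id: str, adapter_id: str = ""):
        self.model_id = model_id
        self.adapter_id = adapter_id


def get_model(model_type: str, model_id: str, adapter_id: str = ""):
    # Only the "gemma" branch from this commit is reproduced.
    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(model_id, adapter_id)
        raise NotImplementedError("Gemma model requires flash attention v2")
    raise ValueError(f"Unsupported model type: {model_type}")


if __name__ == "__main__":
    model = get_model("gemma", "google/gemma-2b")  # model id is illustrative
    print(type(model).__name__)  # FlashGemma
```
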