Infer dtype from model config when not explicitly specified #534

Merged
merged 10 commits into from
Jul 3, 2024
1 change: 1 addition & 0 deletions docs/models/base_models.md
@@ -8,6 +8,7 @@
 - [Zephyr](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta)
 - 🔄 [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)
 - 💎 [Gemma](https://blog.google/technology/developers/gemma-open-models/)
+- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
 - 🏛️ [Phi-3](https://azure.microsoft.com/en-us/blog/introducing-phi-3-redefining-whats-possible-with-slms/) / [Phi-2](https://huggingface.co/microsoft/phi-2)
 - 🔮 [Qwen2 / Qwen](https://huggingface.co/Qwen)
 - 🗣️ [Command-R](https://docs.cohere.com/docs/command-r)
1 change: 1 addition & 0 deletions launcher/src/main.rs
@@ -73,6 +73,7 @@ impl std::fmt::Display for Quantization {

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Dtype {
    #[clap(name = "float16")]
    Float16,
    #[clap(name = "bfloat16")]
    BFloat16,
2 changes: 1 addition & 1 deletion server/lorax_server/cli.py
@@ -79,7 +79,7 @@ def serve(
     dtype = None if dtype is None else dtype.value
     if dtype is not None and quantize is not None:
         raise RuntimeError(
-            "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
+            "Only 1 can be set between `dtype` and `quantize`, as they both decide how the final model is initialized."
         )
     server.serve(
         model_id,
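The check in `serve` can be read in isolation as follows. This is a minimal sketch, not the project's code: the `Dtype` enum below only mirrors the two launcher choices, and `resolve_cli_dtype` is a hypothetical helper name introduced for illustration.

```python
from enum import Enum
from typing import Optional


class Dtype(str, Enum):
    # Assumption: mirrors the two launcher choices shown in the diff.
    float16 = "float16"
    bfloat16 = "bfloat16"


def resolve_cli_dtype(dtype: Optional[Dtype], quantize: Optional[str]) -> Optional[str]:
    """Hypothetical helper: unwrap the enum, then reject dtype and quantize together."""
    dtype_value = None if dtype is None else dtype.value
    if dtype_value is not None and quantize is not None:
        raise RuntimeError(
            "Only 1 can be set between `dtype` and `quantize`, "
            "as they both decide how the final model is initialized."
        )
    return dtype_value


if __name__ == "__main__":
    print(resolve_cli_dtype(Dtype.bfloat16, None))
    print(resolve_cli_dtype(None, None))
```

Returning `None` when no dtype is passed is what lets the server later distinguish "user asked for this dtype" from "use whatever the model config says".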
20 changes: 15 additions & 5 deletions server/lorax_server/models/__init__.py
@@ -15,6 +15,7 @@
 from lorax_server.models.seq2seq_lm import Seq2SeqLM
 from lorax_server.models.t5 import T5Sharded
 from lorax_server.utils.sources import get_s3_model_local_dir
+from lorax_server.utils.torch_utils import is_bf16_supported
 
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
@@ -72,15 +73,24 @@ def get_model(
         raise ValueError(f"Unknown source {source}")
 
     model_type = config_dict["model_type"]
+    is_dtype_provided = dtype is not None
+    dtype = dtype or config_dict.get("torch_dtype", "float16")
 
-    if dtype is None:
-        dtype = torch.float16
-    elif dtype == "float16":
+    if dtype in {"float16", "float32"}:
         dtype = torch.float16
     elif dtype == "bfloat16":
-        dtype = torch.bfloat16
+        if not is_bf16_supported():
+            if is_dtype_provided:
+                raise RuntimeError("bfloat16 is not supported on this device, set --dtype float16.")
+            logger.warning("bfloat16 is not supported on this device, falling back to float16")
+            dtype = torch.float16
+        else:
+            dtype = torch.bfloat16
     else:
-        raise RuntimeError(f"Unknown dtype {dtype}")
+        try:
+            dtype = getattr(torch, dtype)
+        except AttributeError:
+            raise RuntimeError(f"Unknown dtype {dtype}")
 
     if "facebook/galactica" in model_id:
         return GalacticaSharded(
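The selection order in `get_model` (explicit flag first, then the config's `torch_dtype`, then `float16`) may be easier to follow as a standalone function. The sketch below is torch-free and works on dtype names only. `resolve_dtype_name`, the `bf16_supported` parameter, and the `_OTHER_DTYPES` allow-list are assumptions made for illustration. The allow-list stands in for the `getattr(torch, dtype)` lookup in the hunk above.

```python
from typing import Optional

# Assumption: a stand-in for `getattr(torch, dtype)` so the sketch runs without torch.
_OTHER_DTYPES = {"float64"}


def resolve_dtype_name(cli_dtype: Optional[str], config_dict: dict, bf16_supported: bool) -> str:
    """Hypothetical, torch-free rendering of the dtype selection shown in the diff."""
    is_dtype_provided = cli_dtype is not None
    # An explicit flag wins; otherwise use the model config, then float16.
    dtype = cli_dtype or config_dict.get("torch_dtype", "float16")

    if dtype in {"float16", "float32"}:
        # float32 is mapped to half precision, as in the diff.
        return "float16"
    if dtype == "bfloat16":
        if not bf16_supported:
            if is_dtype_provided:
                # The user asked for bfloat16 explicitly, so fail loudly.
                raise RuntimeError("bfloat16 is not supported on this device, set --dtype float16.")
            # The dtype was only inferred from the config; the diff logs a warning here.
            return "float16"
        return "bfloat16"
    if dtype in _OTHER_DTYPES:
        return dtype
    raise RuntimeError(f"Unknown dtype {dtype}")


if __name__ == "__main__":
    print(resolve_dtype_name(None, {"torch_dtype": "bfloat16"}, bf16_supported=False))
```

The asymmetry is the point of the change: an inferred `bfloat16` degrades to `float16` on unsupported hardware, while an explicitly requested one raises.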
10 changes: 10 additions & 0 deletions server/lorax_server/utils/torch_utils.py
@@ -0,0 +1,10 @@
+import torch
+
+
+def is_bf16_supported() -> bool:
+    """Check if the current GPU supports bfloat16.
+
+    Returns:
+        True if supported, False otherwise.
+    """
+    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()
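Because `and` short-circuits, the helper only queries bf16 support when `torch.cuda.is_available()` is true. The sketch below shows the same check in a form that also returns `False` when PyTorch is not installed at all. The `is_bf16_supported_safe` name and the import probe are assumptions for illustration and are not part of this PR.

```python
import importlib.util


def is_bf16_supported_safe() -> bool:
    """Hypothetical variant: report False instead of failing when torch is absent."""
    if importlib.util.find_spec("torch") is None:
        return False
    import torch

    # Same expression as the PR's helper; the bf16 query runs only if CUDA is available.
    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()


if __name__ == "__main__":
    print(is_bf16_supported_safe())
```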