Optimize Mixtral with expert parallelism (#2090)
Yard1 authored Dec 14, 2023
1 parent f1c8520 commit 21d93c1
Showing 6 changed files with 221 additions and 334 deletions.
Dockerfile (14 changes: 1 addition & 13 deletions)
@@ -41,14 +41,6 @@ ENV NVCC_THREADS=$nvcc_threads
 
 RUN python3 setup.py build_ext --inplace
 
-# Build the megablocks library as wheel because it doesn't publish pre-built wheels.
-# https://github.com/stanford-futuredata/megablocks/commit/5897cd6f254b7b3edf7a708a3a3314ecb54b6f78
-RUN apt-get install -y git && \
-    git clone https://github.com/stanford-futuredata/megablocks.git && \
-    cd megablocks && \
-    git checkout 5897cd6f254b7b3edf7a708a3a3314ecb54b6f78 && \
-    MAX_JOBS=8 NVCC_THREADS=8 python3 setup.py bdist_wheel
-
 # image to run unit testing suite
 FROM dev AS test
 
@@ -85,12 +77,8 @@ FROM vllm-base AS vllm-openai
 RUN --mount=type=cache,target=/root/.cache/pip \
     pip install accelerate
 
+COPY vllm vllm
 COPY --from=build /workspace/vllm/*.so /workspace/vllm/
-COPY --from=build /workspace/megablocks/dist/*.whl /tmp/
-RUN --mount=type=cache,target=/root/.cache/pip \
-    pip install /tmp/megablocks-0.5.0-cp310-cp310-linux_x86_64.whl && \
-    rm /tmp/megablocks-0.5.0-cp310-cp310-linux_x86_64.whl
-COPY vllm vllm
 
 ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
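
Since the image's entrypoint still launches the OpenAI-compatible server, a container built from this Dockerfile is queried the same way as before; megablocks is simply no longer baked in. Below is a minimal client sketch; the published port, served model name, and prompt are illustrative assumptions, not part of this commit.

```python
import requests

# Assumes the container was started with port 8000 published and a Mixtral
# checkpoint passed via --model; both values are placeholders.
response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "prompt": "Expert parallelism lets Mixtral",
        "max_tokens": 32,
        "temperature": 0.0,
    },
)
print(response.json()["choices"][0]["text"])
```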

README.md (4 changes: 0 additions & 4 deletions)
@@ -72,10 +72,6 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get
 ```bash
 pip install vllm
 ```
-**NOTE:** The Mixtral model additionally requires `megablocks` which can be installed with pip or [from source](https://github.com/stanford-futuredata/megablocks):
-```bash
-pip install megablocks
-```
 
 ## Getting Started
 
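With the megablocks requirement gone, `pip install vllm` alone covers Mixtral; per this commit, the expert weights are instead partitioned across the tensor-parallel workers. A minimal offline-inference sketch follows; the checkpoint name and two-GPU setup are assumptions for illustration.

```python
from vllm import LLM, SamplingParams

# Expert weights are sharded across the tensor-parallel GPUs,
# so no separate MoE kernel package is required.
llm = LLM(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",  # assumed checkpoint
    tensor_parallel_size=2,                        # assumed 2-GPU node
)

sampling = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], sampling)
print(outputs[0].outputs[0].text)
```
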
docs/source/models/supported_models.rst (3 changes: 1 addition & 2 deletions)
@@ -74,8 +74,7 @@ Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for in
 Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ project.
 
 .. note::
-    Currently, the ROCm version of vLLM does not support Mixtral.
-    Additionally, it only supports Mistral for context lengths up to 4096.
+    Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
 
 .. tip::
     The easiest way to check if your model is supported is to run the program below:
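
The updated note caps ROCm context lengths at 4096 for both Mistral and Mixtral because sliding-window attention is not yet available in ROCm's flash attention. A hedged sketch of staying within that limit; the checkpoint name is a placeholder and `max_model_len` is used on the assumption that it is the right knob for capping the context window.

```python
from vllm import LLM

# On ROCm, keep the context window at or below 4096 tokens until
# sliding-window attention lands in ROCm's flash attention.
llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.1",  # assumed checkpoint
    max_model_len=4096,
)
```
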
vllm/config.py (16 changes: 9 additions & 7 deletions)
@@ -120,14 +120,16 @@ def _verify_load_format(self) -> None:
         if load_format == "auto":
             load_format = "pt"
 
-        # FIXME(woosuk): This is a temporary hack. Support safetensor weights.
+        # TODO: Remove this check once HF updates the pt weights of Mixtral.
         architectures = getattr(self.hf_config, "architectures", [])
-        if "MixtralForCausalLM" in architectures and load_format != "pt":
-            logger.info(
-                "Currently, only 'pt' format is supported for Mixtral. "
-                "Changing the format to 'pt'. This may re-download the "
-                "weights if you have downloaded the safetensor weights.")
-            load_format = "pt"
+        if "MixtralForCausalLM" in architectures:
+            if load_format == "pt":
+                raise ValueError(
+                    "Currently, the 'pt' format is not supported for Mixtral. "
+                    "Please use the 'safetensors' format instead. ")
+            elif load_format == "auto":
+                # Do not fall back to pt weights.
+                load_format = "safetensors"
 
         self.load_format = load_format
 
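The new check inverts the old behaviour: instead of silently forcing Mixtral to 'pt' weights, vLLM now rejects 'pt' outright and resolves 'auto' to safetensors. Below is a standalone restatement of that rule, useful for seeing the two branches in isolation; the helper name and the asserts are illustrative, not part of the codebase.

```python
def resolve_mixtral_load_format(load_format: str, architectures: list) -> str:
    """Simplified sketch of the check added to ModelConfig._verify_load_format."""
    if "MixtralForCausalLM" in architectures:
        if load_format == "pt":
            raise ValueError(
                "Currently, the 'pt' format is not supported for Mixtral. "
                "Please use the 'safetensors' format instead.")
        if load_format == "auto":
            # Do not fall back to pt weights.
            return "safetensors"
    return load_format


# 'auto' now resolves to safetensors for Mixtral; other architectures are untouched.
assert resolve_mixtral_load_format("auto", ["MixtralForCausalLM"]) == "safetensors"
assert resolve_mixtral_load_format("auto", ["LlamaForCausalLM"]) == "auto"
```
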
vllm/model_executor/models/__init__.py (4 changes: 3 additions & 1 deletion)
@@ -39,13 +39,15 @@
 }
 
 # Models not supported by ROCm.
-_ROCM_UNSUPPORTED_MODELS = ["MixtralForCausalLM"]
+_ROCM_UNSUPPORTED_MODELS = []
 
 # Models partially supported by ROCm.
 # Architecture -> Reason.
 _ROCM_PARTIALLY_SUPPORTED_MODELS = {
     "MistralForCausalLM":
     "Sliding window attention is not yet supported in ROCm's flash attention",
+    "MixtralForCausalLM":
+    "Sliding window attention is not yet supported in ROCm's flash attention",
 }
 
 
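With Mixtral moved from the unsupported list to the partially-supported map, ROCm builds can now load it and only warn about the sliding-window limitation. The sketch below shows how such registries are typically consulted at load time; the dictionaries copy the post-change values from the diff, while the function and its messages are illustrative rather than the actual vLLM loader code.

```python
import logging

logger = logging.getLogger(__name__)

# Post-change registry values, copied from the diff above.
_ROCM_UNSUPPORTED_MODELS = []
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
    "MistralForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
    "MixtralForCausalLM":
    "Sliding window attention is not yet supported in ROCm's flash attention",
}


def check_rocm_support(model_arch: str) -> None:
    """Illustrative gate over the two registries; not the real loader code."""
    if model_arch in _ROCM_UNSUPPORTED_MODELS:
        raise ValueError(f"{model_arch} is not supported on ROCm.")
    reason = _ROCM_PARTIALLY_SUPPORTED_MODELS.get(model_arch)
    if reason is not None:
        logger.warning("%s is only partially supported on ROCm: %s",
                       model_arch, reason)


# After this change, Mixtral no longer raises; it only triggers the warning.
check_rocm_support("MixtralForCausalLM")
```
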
[Diff for the remaining changed file not shown.]
