Skip to content

Commit

Permalink
Merge branch 'main' into fix_guided_dec_with_mistral_tokenizer_mode
Browse files Browse the repository at this point in the history
Signed-off-by: Wallas Santos <[email protected]>
  • Loading branch information
wallashss committed Dec 17, 2024
2 parents 0173c2d + f9ecbb1 commit b98f633
Show file tree
Hide file tree
Showing 181 changed files with 7,589 additions and 2,567 deletions.
25 changes: 25 additions & 0 deletions .buildkite/run-gh200-test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/bin/bash

# This script build the GH200 docker image and run the offline inference inside the container.
# It serves a sanity check for compilation and basic model usage.
set -ex

# Try building the docker image
DOCKER_BUILDKIT=1 docker build . \
--target vllm-openai \
--platform "linux/arm64" \
-t gh200-test \
--build-arg max_jobs=66 \
--build-arg nvcc_threads=2 \
--build-arg torch_cuda_arch_list="9.0+PTX" \
--build-arg vllm_fa_cmake_gpu_arches="90-real"

# Setup cleanup
remove_docker_container() { docker rm -f gh200-test || true; }
trap remove_docker_container EXIT
remove_docker_container

# Run the image and test offline inference
docker run --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c '
python3 examples/offline_inference.py
'
9 changes: 6 additions & 3 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,14 @@ steps:
commands:
- VLLM_USE_V1=1 pytest -v -s v1

- label: Examples Test # 15min
- label: Examples Test # 25min
working_dir: "/vllm-workspace/examples"
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/entrypoints
- examples/
commands:
- pip install awscli tensorizer # for llava example and tensorizer test
- pip install tensorizer # for tensorizer test
- python3 offline_inference.py
- python3 cpu_offload.py
- python3 offline_inference_chat.py
Expand All @@ -198,7 +198,10 @@ steps:
- python3 offline_inference_vision_language_multi_image.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference_encoder_decoder.py
- python3 offline_profile.py --model facebook/opt-125m
- python3 offline_inference_classification.py
- python3 offline_inference_embedding.py
- python3 offline_inference_scoring.py
- python3 offline_profile.py --model facebook/opt-125m run_num_steps --num-steps 2

- label: Prefix Caching Test # 9min
mirror_hardwares: [amd]
Expand Down
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
Expand Down Expand Up @@ -300,7 +301,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
#
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
Expand Down
40 changes: 32 additions & 8 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ ARG CUDA_VERSION=12.4.1
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
ARG CUDA_VERSION=12.4.1
ARG PYTHON_VERSION=3.12
ARG TARGETPLATFORM
ENV DEBIAN_FRONTEND=noninteractive

# Install Python and other dependencies
Expand Down Expand Up @@ -46,9 +47,14 @@ WORKDIR /workspace
# install build and runtime dependencies
COPY requirements-common.txt requirements-common.txt
COPY requirements-cuda.txt requirements-cuda.txt
COPY requirements-cuda-arm64.txt requirements-cuda-arm64.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-cuda.txt

RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
python3 -m pip install -r requirements-cuda-arm64.txt; \
fi

# cuda arch list used by torch
# can be useful for both `dev` and `test`
Expand All @@ -63,13 +69,19 @@ ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches}

#################### WHEEL BUILD IMAGE ####################
FROM base AS build
ARG TARGETPLATFORM

# install build dependencies
COPY requirements-build.txt requirements-build.txt

RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-build.txt

RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
python3 -m pip install -r requirements-cuda-arm64.txt; \
fi

COPY . .
ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
Expand Down Expand Up @@ -134,15 +146,18 @@ COPY requirements-test.txt requirements-test.txt
COPY requirements-dev.txt requirements-dev.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-dev.txt

#################### DEV IMAGE ####################

#################### vLLM installation IMAGE ####################
# image with vLLM installed
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base
ARG CUDA_VERSION=12.4.1
ARG PYTHON_VERSION=3.12
WORKDIR /vllm-workspace
ENV DEBIAN_FRONTEND=noninteractive
ARG TARGETPLATFORM

COPY requirements-cuda-arm64.txt requirements-cuda-arm64.txt

RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment
Expand All @@ -168,18 +183,25 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
# or future versions of triton.
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/

# install vllm wheel first, so that torch etc will be installed
# Install vllm wheel first, so that torch etc will be installed.
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install dist/*.whl --verbose

RUN --mount=type=cache,target=/root/.cache/pip \
. /etc/environment && \
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
pip uninstall -y torch && \
python3 -m pip install -r requirements-cuda-arm64.txt; \
fi

RUN --mount=type=cache,target=/root/.cache/pip \
. /etc/environment && \
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
fi
COPY examples examples
#################### vLLM installation IMAGE ####################


#################### TEST IMAGE ####################
# image to run unit testing suite
# note that this uses vllm installed by `pip`
Expand Down Expand Up @@ -209,7 +231,6 @@ COPY vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1
RUN mkdir test_docs
RUN mv docs test_docs/
RUN mv vllm test_docs/

#################### TEST IMAGE ####################

#################### OPENAI API SERVER ####################
Expand All @@ -218,8 +239,11 @@ FROM vllm-base AS vllm-openai

# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.44.0' timm==0.9.10

if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10'; \
else \
pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10'; \
fi
ENV VLLM_USAGE_SOURCE production-docker-image

ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
Expand Down
12 changes: 12 additions & 0 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ def main(args: argparse.Namespace):
backend = args.backend
model_id = args.model
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer_mode = args.tokenizer_mode

if args.base_url is not None:
api_url = f"{args.base_url}{args.endpoint}"
Expand All @@ -790,6 +791,7 @@ def main(args: argparse.Namespace):
base_url = f"http://{args.host}:{args.port}"

tokenizer = get_tokenizer(tokenizer_id,
tokenizer_mode=tokenizer_mode,
trust_remote_code=args.trust_remote_code)

if args.dataset is not None:
Expand Down Expand Up @@ -1210,5 +1212,15 @@ def main(args: argparse.Namespace):
"from the sampled HF dataset.",
)

parser.add_argument(
'--tokenizer-mode',
type=str,
default="auto",
choices=['auto', 'slow', 'mistral'],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')

args = parser.parse_args()
main(args)
173 changes: 173 additions & 0 deletions benchmarks/fused_kernels/layernorm_rms_benchmarks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import pickle as pkl
import time
from dataclasses import dataclass
from itertools import product
from typing import Callable, Iterable, List, Optional

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm

import vllm._custom_ops as ops
from vllm.model_executor.layers.layernorm import RMSNorm


@dataclass
class bench_params_t:
num_tokens: int
hidden_size: int
add_residual: bool
dtype: torch.dtype

def description(self):
return (f'N {self.num_tokens} '
f'x D {self.hidden_size} '
f'x R {self.add_residual} '
f'x DT {self.dtype}')


def get_bench_params() -> List[bench_params_t]:
## Test Fixtures
NUM_TOKENS = [2**x for x in range(11)]
HIDDEN_SIZES = list(range(1024, 8129, 1024))
ADD_RESIDUAL = [True, False]
DTYPES = [torch.bfloat16, torch.float]

combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
bench_params = list(map(lambda x: \
bench_params_t(x[0], x[1], x[2], x[3]), combinations))
return bench_params


# Reference impls
def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype):
# Norm
torch_out = None
if residual is None:
torch_out = rms_norm_layer.forward_cuda(x, residual)
else:
torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

# Quant
torch_out, _, _ = ops.scaled_int8_quant(torch_out)


def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype):
# Norm
torch_out = None
if residual is None:
torch_out = rms_norm_layer.forward_cuda(x, residual)
else:
torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

# Quant
torch_out, _ = ops.scaled_fp8_quant(torch_out)


def fused_impl(
rms_norm_layer: RMSNorm, # this stores the weights
x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype):
out, _ = ops.rms_norm_dynamic_per_token_quant(x,
rms_norm_layer.weight,
1e-6,
quant_dtype,
residual=residual)


# Bench functions
def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor,
quant_dtype: torch.dtype, label: str, sub_label: str,
fn: Callable, description: str) -> TMeasurement:

min_run_time = 1

globals = {
"rms_norm_layer": rms_norm_layer,
"x": x,
"residual": residual,
"quant_dtype": quant_dtype,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(rms_norm_layer, x, residual, quant_dtype)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)

def bench(params: bench_params_t, label: str, sub_label: str) \
-> Iterable[TMeasurement]:

# Make inputs
layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
# Make weights
layer.weight.data.normal_(mean=1.0, std=0.1)
# Make inputs
scale = 1 / params.hidden_size
x = torch.randn(params.num_tokens,
params.hidden_size,
dtype=params.dtype,
device='cuda') * scale
residual = (torch.randn_like(x) * scale).to(device='cuda') \
if params.add_residual else None

timers = []

# unfused int8 impl.
timers.append(
bench_fn(layer, x, residual, torch.int8, label, sub_label,
unfused_int8_impl, "unfused_int8_impl"))

# unfused fp8 impl.
timers.append(
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
unfused_fp8_impl, "unfused_fp8_impl"))

# fused int8 impl.
timers.append(
bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl,
"fused_int8_impl"))

# fused fp8 impl.
timers.append(
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
fused_impl, "fused_fp8_impl"))

print_timers(timers)

return timers


# launch bench
# runner
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()


def main():
torch.set_default_device('cuda')
bench_params = get_bench_params()

timers = []
for bp in tqdm(bench_params):
timers.extend(
bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
print_timers(timers)

# pickle all the results
timestamp = int(time.time())
with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f:
pkl.dump(timers, f)


if __name__ == '__main__':
main()
Loading

0 comments on commit b98f633

Please sign in to comment.