
Commit

Merge branch 'main' into w8a8-input-scale-none
mgoin committed Jul 20, 2024
2 parents 63bc660 + 9364f74 commit 22e9855
Showing 76 changed files with 878 additions and 300 deletions.
@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5
model_name: "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.905
- name: "exact_match,flexible-extract"
value: 0.905
limit: 1000
num_fewshot: 5
@@ -4,8 +4,8 @@ tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.769
value: 0.752
- name: "exact_match,flexible-extract"
value: 0.769
value: 0.754
limit: 1000
num_fewshot: 5
@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.753
- name: "exact_match,flexible-extract"
value: 0.753
limit: 1000
num_fewshot: 5
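
The new configs above record expected GSM8K accuracies that the lm-eval-harness CI scripts check against a fresh run. As a rough illustration of that comparison (not the harness's actual test code), here is a minimal Python sketch; the results-dict layout and the 5% relative tolerance are assumptions.

```python
# Illustrative sketch only: compares measured lm-eval metrics against the
# expected values in a config like the ones added above. The results-dict
# shape and RTOL are assumptions, not taken from the actual CI test.
import yaml

RTOL = 0.05  # assumed relative tolerance


def check_eval_results(config_path: str, results: dict) -> None:
    with open(config_path) as f:
        config = yaml.safe_load(f)

    for task in config["tasks"]:
        measured = results["results"][task["name"]]
        for metric in task["metrics"]:
            expected = metric["value"]
            got = measured[metric["name"]]
            assert abs(got - expected) <= RTOL * expected, (
                f"{task['name']}/{metric['name']}: expected {expected}, got {got}"
            )
```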
1 change: 1 addition & 0 deletions .buildkite/lm-eval-harness/configs/models-large.txt
@@ -1,3 +1,4 @@
Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
Meta-Llama-3-70B-Instruct.yaml
Mixtral-8x7B-Instruct-v0.1.yaml
Qwen2-57B-A14-Instruct.yaml
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
done

lm_eval --model vllm \
--model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \
--model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \
--tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
--batch_size $BATCH_SIZE
7 changes: 7 additions & 0 deletions .buildkite/run-amd-test.sh
@@ -66,11 +66,18 @@ trap remove_docker_container EXIT

echo "--- Running container"

HF_CACHE="$(realpath ~)/huggingface"
mkdir -p ${HF_CACHE}
HF_MOUNT="/root/.cache/huggingface"

docker run \
--device /dev/kfd --device /dev/dri \
--network host \
--shm-size=16gb \
--rm \
-e HF_TOKEN \
-v ${HF_CACHE}:${HF_MOUNT} \
-e HF_HOME=${HF_MOUNT} \
--name ${container_name} \
${image_name} \
/bin/bash -c "${@}"
3 changes: 2 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -44,7 +44,8 @@ steps:
mirror_hardwares: [amd]
fast_check: true
commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
# This flashinfer installation will fail on AMD ROCm, so it is set as optional.
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl || true
- pytest -v -s basic_correctness/test_basic_correctness.py
- pytest -v -s basic_correctness/test_cpu_offload.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -33,7 +33,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# versions are derived from Dockerfile.rocm
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.3.1")
set(TORCH_SUPPORTED_VERSION_ROCM "2.4.0")
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")

#
# Try to find python package with an executable that exactly matches
@@ -101,7 +101,7 @@ elseif(HIP_FOUND)
# ROCm 5.X and 6.X
if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} "
message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
"expected for ROCm build, saw ${Torch_VERSION} instead.")
endif()
else()
2 changes: 1 addition & 1 deletion Dockerfile
@@ -168,7 +168,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& python3 --version

RUN apt-get update -y \
&& apt-get install -y python3-pip git curl
&& apt-get install -y python3-pip git curl libibverbs-dev

# Install pip s.t. it will be compatible with our PYTHON_VERSION
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION}
60 changes: 35 additions & 25 deletions Dockerfile.rocm
@@ -4,18 +4,21 @@ ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
# Default ROCm ARCHes to build vLLM for.
ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"

# Whether to build CK-based flash-attention
# If 0, will not build flash attention
# This is useful for gfx target where flash-attention is not supported
# (i.e. those that do not appear in `FA_GFX_ARCHS`)
# Triton FA is used by default on ROCm now so this is unnecessary.
# Whether to install CK-based flash-attention
# If 0, will not install flash-attention
ARG BUILD_FA="1"
# If `TRY_FA_WHEEL=1`, we will try installing flash-attention from `FA_WHEEL_URL`
# If this succeeds, we use the downloaded wheel and skip building flash-attention.
# Otherwise, ROCm flash-attention from `FA_BRANCH` will be built for the
# architectures specified in `FA_GFX_ARCHS`
ARG TRY_FA_WHEEL="1"
ARG FA_WHEEL_URL="https://github.com/ROCm/flash-attention/releases/download/v2.5.9post1-cktile-vllm/flash_attn-2.5.9.post1-cp39-cp39-linux_x86_64.whl"
ARG FA_GFX_ARCHS="gfx90a;gfx942"
ARG FA_BRANCH="ae7928c"
ARG FA_BRANCH="23a2b1c2"

# Whether to build triton on rocm
ARG BUILD_TRITON="1"
ARG TRITON_BRANCH="0ef1848"
ARG TRITON_BRANCH="e0fc12c"

### Base image build stage
FROM $BASE_IMAGE AS base
@@ -43,15 +46,15 @@ RUN apt-get update && apt-get install -y \
ARG APP_MOUNT=/vllm-workspace
WORKDIR ${APP_MOUNT}

RUN pip install --upgrade pip
RUN python3 -m pip install --upgrade pip
# Remove sccache so it doesn't interfere with ccache
# TODO: implement sccache support across components
RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)"
RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
# Install torch == 2.5.0 on ROCm
RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-6.1"*) \
pip uninstall -y torch torchaudio torchvision \
&& pip install --no-cache-dir --pre \
python3 -m pip uninstall -y torch torchaudio torchvision \
&& python3 -m pip install --no-cache-dir --pre \
torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
torchvision==0.20.0.dev20240710 \
--index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
@@ -70,24 +73,31 @@ ENV CCACHE_DIR=/root/.cache/ccache
FROM base AS build_amdsmi
# Build amdsmi wheel always
RUN cd /opt/rocm/share/amd_smi \
&& pip wheel . --wheel-dir=/install
&& python3 -m pip wheel . --wheel-dir=/install


### Flash-Attention wheel build stage
FROM base AS build_fa
ARG BUILD_FA
ARG TRY_FA_WHEEL
ARG FA_WHEEL_URL
ARG FA_GFX_ARCHS
ARG FA_BRANCH
# Build ROCm flash-attention wheel if `BUILD_FA = 1`
RUN --mount=type=cache,target=${CCACHE_DIR} \
if [ "$BUILD_FA" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& git clone https://github.com/ROCm/flash-attention.git \
&& cd flash-attention \
&& git checkout "${FA_BRANCH}" \
&& git submodule update --init \
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
if [ "${TRY_FA_WHEEL}" = "1" ] && python3 -m pip install "${FA_WHEEL_URL}"; then \
# If a suitable wheel exists, we download it instead of building FA
mkdir -p /install && wget -N "${FA_WHEEL_URL}" -P /install; \
else \
mkdir -p libs \
&& cd libs \
&& git clone https://github.com/ROCm/flash-attention.git \
&& cd flash-attention \
&& git checkout "${FA_BRANCH}" \
&& git submodule update --init \
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
fi; \
# Create an empty directory otherwise as later build stages expect one
else mkdir -p /install; \
fi
@@ -126,7 +136,7 @@ RUN case "$(which python3)" in \

# Package upgrades for useful functionality or to avoid dependency issues
RUN --mount=type=cache,target=/root/.cache/pip \
pip install --upgrade numba scipy huggingface-hub[cli]
python3 -m pip install --upgrade numba scipy huggingface-hub[cli]

# Make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
@@ -137,7 +147,7 @@ ENV TOKENIZERS_PARALLELISM=false

RUN --mount=type=cache,target=${CCACHE_DIR} \
--mount=type=cache,target=/root/.cache/pip \
pip install -U -r requirements-rocm.txt \
python3 -m pip install -Ur requirements-rocm.txt \
&& case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-6.1"*) \
# Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM
@@ -153,27 +163,27 @@ RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
mkdir -p libs \
&& cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y amdsmi;
&& python3 -m pip uninstall -y amdsmi;

# Copy triton wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
mkdir -p libs \
&& if ls /install/*.whl; then \
cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y triton; fi
&& python3 -m pip uninstall -y triton; fi

# Copy flash-attn wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
mkdir -p libs \
&& if ls /install/*.whl; then \
cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y flash-attn; fi
&& python3 -m pip uninstall -y flash-attn; fi

# Install wheels that were built to the final image
RUN --mount=type=cache,target=/root/.cache/pip \
if ls libs/*.whl; then \
pip install libs/*.whl; fi
python3 -m pip install libs/*.whl; fi

CMD ["/bin/bash"]
6 changes: 3 additions & 3 deletions csrc/ops.h
@@ -134,9 +134,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scale);

void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scale);
void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
c10::optional<torch::Tensor> const& scale_ub);

void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
73 changes: 48 additions & 25 deletions csrc/quantization/fp8/common.cu
@@ -23,10 +23,16 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {

#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()

template <typename scalar_t>
template <bool is_scale_inverted>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
const scalar_t val, const float inverted_scale) {
float x = static_cast<float>(val) * inverted_scale;
float const val, float const scale) {
float x = 0.0f;
if constexpr (is_scale_inverted) {
x = val * scale;
} else {
x = val / scale;
}

float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
return static_cast<c10::Float8_e4m3fn>(r);
}
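
The helper above now takes an `is_scale_inverted` template flag: the per-tensor path keeps passing a precomputed reciprocal and multiplying, while the per-token path divides by the scale directly. A minimal NumPy sketch of that math, assuming FP8 E4M3 with a max finite value of 448.0 and ignoring the actual fp8 rounding:

```python
# Sketch of the scaled_fp8_conversion logic above; the fp8 cast itself is not
# modeled, only the scale application and the clamp to the E4M3 range.
import numpy as np

FP8_E4M3_MAX = 448.0  # max finite value of float8_e4m3fn


def scaled_fp8_conversion(val: float, scale: float, is_scale_inverted: bool) -> float:
    x = val * scale if is_scale_inverted else val / scale
    return float(np.clip(x, -FP8_E4M3_MAX, FP8_E4M3_MAX))
```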
@@ -117,10 +123,10 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
return absmax_val;
}

template <typename scalar_t>
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
scalar_t const* __restrict__ input,
float const inverted_scale,
float const scale,
int64_t const num_elems,
int const tid, int const step) {
// Vectorized input/output to better utilize memory bandwidth.
@@ -135,16 +141,21 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;

out_vec.x = scaled_fp8_conversion(in_vec.x, inverted_scale);
out_vec.y = scaled_fp8_conversion(in_vec.y, inverted_scale);
out_vec.z = scaled_fp8_conversion(in_vec.z, inverted_scale);
out_vec.w = scaled_fp8_conversion(in_vec.w, inverted_scale);
out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.x), scale);
out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.y), scale);
out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.z), scale);
out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.w), scale);
vectorized_out[i] = out_vec;
}

// Handle the remaining elements if num_elems is not divisible by 4
for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
out[i] = scaled_fp8_conversion(input[i], inverted_scale);
out[i] = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(input[i]), scale);
}
}

@@ -158,15 +169,17 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
// Invert the scale so that we can use multiplications to avoid expensive
// division.
const float inverted_scale = 1.0f / (*scale);

scaled_fp8_conversion_vec(out, input, inverted_scale, num_elems, tid,
blockDim.x * gridDim.x);
scaled_fp8_conversion_vec<scalar_t, true>(
out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
}

template <typename scalar_t>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
scalar_t const* __restrict__ input, const int hidden_size) {
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
const int hidden_size) {
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);

int const tid = threadIdx.x;
int const token_idx = blockIdx.x;

@@ -188,20 +201,27 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
}

float const block_absmax_val_maybe = blockReduceMax(absmax_val);
__shared__ float block_absmax_val;
__shared__ float token_scale;
if (tid == 0) {
block_absmax_val = block_absmax_val_maybe;
scale[token_idx] = block_absmax_val / FP8_E4M3_MAX;
if (scale_ub) {
token_scale = min(block_absmax_val_maybe, *scale_ub);
} else {
token_scale = block_absmax_val_maybe;
}
// token scale computation
token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
scale[token_idx] = token_scale;
}
__syncthreads();

float const inverted_scale = FP8_E4M3_MAX / block_absmax_val;
// Note that we don't use inverted scales so we can match FBGemm impl.
if (can_vectorize) {
scaled_fp8_conversion_vec(token_output, token_input, inverted_scale,
hidden_size, tid, blockDim.x);
scaled_fp8_conversion_vec<scalar_t, false>(
token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
token_output[i] = scaled_fp8_conversion(token_input[i], inverted_scale);
token_output[i] = scaled_fp8_conversion<false>(
static_cast<float>(token_input[i]), token_scale);
}
}
}
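
Putting the pieces of the kernel above together: each token's scale is its absmax (optionally clamped by `scale_ub`), divided by the fp8 max and floored at `min_scaling_factor`, and the values are then divided by that scale rather than multiplied by its reciprocal, to match the FBGEMM implementation. A rough NumPy reference for a 2-D `[num_tokens, hidden_size]` input, skipping the CUDA details (block reduction, vectorized loads, the real fp8 cast):

```python
# Reference-only sketch of dynamic per-token fp8 quantization with an optional
# scale upper bound, mirroring the kernel logic above.
import numpy as np

FP8_E4M3_MAX = 448.0
MIN_SCALING_FACTOR = 1.0 / (FP8_E4M3_MAX * 512.0)


def dynamic_per_token_fp8_quant_ref(x: np.ndarray, scale_ub: float | None = None):
    absmax = np.abs(x).max(axis=-1, keepdims=True)        # per-token absmax
    if scale_ub is not None:
        absmax = np.minimum(absmax, scale_ub)             # clamp by the upper bound
    scale = np.maximum(absmax / FP8_E4M3_MAX, MIN_SCALING_FACTOR)
    q = np.clip(x / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)   # divide, as in the FBGEMM-style path
    return q, scale
```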
@@ -246,9 +266,10 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
});
}

void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scales) {
void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());

@@ -264,6 +285,8 @@ void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d]
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
input.data_ptr<scalar_t>(), hidden_size);
input.data_ptr<scalar_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
hidden_size);
});
}
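
On the Python side, the new optional `scale_ub` argument would simply be passed (or omitted) when invoking the op. A hypothetical call pattern, assuming the op is exposed through vLLM's usual `torch.ops._C` registration and that `out`/`scales` are pre-allocated by the caller; the binding path and the scales layout are assumptions, not shown in this diff:

```python
# Hypothetical usage sketch; the torch.ops._C binding name and the scales
# layout are assumptions for illustration.
import torch


def per_token_fp8_quant(x: torch.Tensor, scale_ub: torch.Tensor | None = None):
    out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    scales = torch.empty(x.shape[0], 1, device=x.device, dtype=torch.float32)
    torch.ops._C.dynamic_per_token_scaled_fp8_quant(out, x, scales, scale_ub)
    return out, scales
```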