From a9fea8983bf647fcccbb21b9d97aca0042f80bde Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Sun, 25 May 2025 18:42:25 +0800 Subject: [PATCH 1/7] optimize expert parallelism implemented with all2all Signed-off-by: SlightwindSec --- vllm_ascend/quantization/w8a8_dynamic.py | 153 +++++++++++++++++++---- 1 file changed, 129 insertions(+), 24 deletions(-) diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 5d2b442cf..b1eea5890 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -29,6 +29,73 @@ VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 +def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, + max_row_per_ep_rank: int, num_tokens: int, + top_k: int) -> tuple[torch.Tensor, torch.Tensor]: + original_total_elements = num_tokens * top_k + device = topk_ids.device + original_dtype = topk_ids.dtype + + if original_total_elements == 0: + output_len = ep_size * max_row_per_ep_rank + topk_ids_pad = torch.full((output_len, ), + expert_num, + dtype=original_dtype, + device=device) + unpad_indices = torch.full((original_total_elements, ), + -1, + dtype=torch.long, + device=device) + return topk_ids_pad, unpad_indices + + experts_per_ep_rank_val = expert_num // ep_size + if experts_per_ep_rank_val == 0: + raise ValueError( + "expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. " + "Ensure expert_num >= ep_size.") + + assigned_ep_rank = (topk_ids.float() / + experts_per_ep_rank_val).to(original_dtype) + indices_arange = torch.arange(topk_ids.shape[0], device=device) + + is_new_segment = torch.cat( + (torch.tensor([True], device=device), assigned_ep_rank[1:] + != assigned_ep_rank[:-1])) + temp_start_markers = torch.full_like(indices_arange, + -1, + dtype=indices_arange.dtype) + temp_start_markers[is_new_segment] = indices_arange[is_new_segment] + start_offset_for_each_token = torch.cummax(temp_start_markers.float(), + dim=0)[0].to( + temp_start_markers.dtype) + token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token + is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank + cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long) + indices_in_rec_cond_list_for_all = cumsum_kept - 1 + unpad_indices = torch.where( + is_kept_mask, indices_in_rec_cond_list_for_all, + torch.tensor(-1, device=device, dtype=torch.long)) + output_len = ep_size * max_row_per_ep_rank + topk_ids_pad = torch.full((output_len, ), + expert_num, + dtype=original_dtype, + device=device) + if topk_ids.shape[0] > 0: + all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx + temp_pad_buffer = torch.full((output_len + 1, ), + expert_num, + dtype=original_dtype, + device=device) + output_len_tensor = torch.tensor(output_len, + dtype=torch.long, + device=device) + scatter_indices = torch.where(is_kept_mask, all_destination_indices, + output_len_tensor) + temp_pad_buffer.scatter_(0, scatter_indices, topk_ids) + topk_ids_pad = temp_pad_buffer[:output_len] + return topk_ids_pad, unpad_indices + + def apply_mlp(hidden_states_wrapper: List[torch.Tensor], w1: torch.Tensor, w1_scale: torch.Tensor, @@ -236,28 +303,48 @@ def fused_experts_with_all2all( expert_idx=topk_ids, active_num=num_tokens) - global_expert_tokens = torch.bincount(expanded_expert_idx, - minlength=global_num_experts) - scatter_sizes = global_expert_tokens.view(ep_group.world_size, - -1).sum(-1) - - gather_sizes = torch.empty_like(scatter_sizes) - 
dist.all_to_all_single(gather_sizes, - scatter_sizes, + local_buffer_rows = (num_tokens // ep_group.world_size + + 1) * ep_group.world_size * top_k * 2 + max_row_per_ep_rank = local_buffer_rows // ep_group.world_size + expert_idx_buffer_scatter, unpad_indices = process_topk_ids( + expanded_expert_idx, global_num_experts, ep_group.world_size, + max_row_per_ep_rank, num_tokens, top_k) + hidden_states_pad_idx = torch.zeros( + expert_idx_buffer_scatter.shape, + dtype=expert_idx_buffer_scatter.dtype, + device=expert_idx_buffer_scatter.device) + non_pad_len = torch.sum((expert_idx_buffer_scatter + != global_num_experts).to(torch.int32)) + hidden_states_pad_idx[expert_idx_buffer_scatter != + global_num_experts] = torch.arange( + non_pad_len, + dtype=expert_idx_buffer_scatter.dtype, + device=hidden_states.device) + expert_idx_buffer_gather = torch.empty_like( + expert_idx_buffer_scatter, + dtype=expert_idx_buffer_scatter.dtype, + device=expert_idx_buffer_scatter.device) + dist.all_to_all_single(expert_idx_buffer_gather, + expert_idx_buffer_scatter, group=ep_group.device_group) - scatter_size_list = scatter_sizes.cpu().tolist() - gather_size_list = gather_sizes.cpu().tolist() - - expanded_expert_idx = expanded_expert_idx % local_num_experts - hidden_states = ep_group.all_to_all(hidden_states, 0, 0, - scatter_size_list, - gather_size_list) - local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0, - scatter_size_list, - gather_size_list) + hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx] - sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx) + hidden_states_buffer_gather = torch.empty_like( + hidden_states_buffer_scatter, + dtype=hidden_states_buffer_scatter.dtype, + device=hidden_states_buffer_scatter.device) + dist.all_to_all_single(hidden_states_buffer_gather, + hidden_states_buffer_scatter, + group=ep_group.device_group) + mask = expert_idx_buffer_gather != global_num_experts + local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( + global_num_experts // ep_group.world_size) + hidden_states = hidden_states_buffer_gather[mask] + idx_type = local_expert_idx.dtype + sorted_local_expert_idx, sorted_idx = torch.sort( + local_expert_idx.float()) + sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type) expert_tokens = torch_npu.npu_moe_compute_expert_tokens( sorted_local_expert_idx, local_num_experts).to(torch.int64) @@ -293,12 +380,30 @@ def fused_experts_with_all2all( group_list_type=group_list_type) if expert_map is not None: - resorted_idx = torch.argsort(sorted_idx) + idx_type = sorted_idx.dtype + resorted_idx = torch.argsort(sorted_idx.float()).to(idx_type) hidden_states = hidden_states[resorted_idx] - hidden_states = ep_group.all_to_all(hidden_states, 0, 0, - gather_size_list, - scatter_size_list) - + hidden_states_scatter = torch.zeros( + (mask.shape[0], hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device) + hidden_states_scatter[mask] = hidden_states + hidden_states_gatter = torch.empty_like( + hidden_states_scatter, + dtype=hidden_states_scatter.dtype, + device=hidden_states_scatter.device) + dist.all_to_all_single(hidden_states_gatter, + hidden_states_scatter, + group=ep_group.device_group) + hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter + != global_num_experts] + if hidden_states_gatter.shape[0] != row_idx_len: + hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device) + hidden_states[unpad_indices != 
-1] = hidden_states_gatter + else: + hidden_states = hidden_states_gatter final_hidden_states = torch_npu.npu_moe_finalize_routing( hidden_states, skip1=None, From a0c3e9ba506e8c02b79a1cbe4d3e4daca15eb1d9 Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Mon, 26 May 2025 10:33:28 +0800 Subject: [PATCH 2/7] [Bugfix] Adjust inputbatch to be compatible with latest vllm (#945) Adjust inputbatch to be compatible with latest vllm, as kvcache group feature has been redo in https://github.com/vllm-project/vllm/pull/18593 --------- Signed-off-by: MengqingCao --- vllm_ascend/attention/attention_v1.py | 11 +++++-- vllm_ascend/attention/mla_v1.py | 11 +++++-- vllm_ascend/worker/model_runner_v1.py | 44 +++++++++------------------ 3 files changed, 33 insertions(+), 33 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 36ac97207..61e26e19d 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -30,6 +30,7 @@ from vllm.v1.worker.gpu_input_batch import InputBatch from vllm_ascend.ops.attention import vanilla_chunked_prefill +from vllm_ascend.utils import vllm_version_is class AscendAttentionBackend(AttentionBackend): @@ -141,8 +142,14 @@ def reorder_batch(self, input_batch: "InputBatch", def build(self, num_reqs, num_actual_tokens, max_query_len, common_prefix_len): - block_table = ( - self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) + if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"): + block_table = (self.runner.input_batch.block_table. + get_device_tensor()[:num_reqs]) + else: + block_table = self.runner.input_batch.block_table[ + 0].get_device_tensor() + block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = ( + block_table[:num_reqs]) query_lens = self.runner.query_lens seq_lens = self.runner.seq_lens_cpu[:num_reqs] diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index d987eaba9..eb40f4115 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -16,6 +16,7 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla +from vllm_ascend.utils import vllm_version_is from vllm_ascend.worker.model_runner_v1 import NPUModelRunner if TYPE_CHECKING: @@ -238,8 +239,14 @@ def build(self, # function. We should avoid GPU -> CPU sync as much as possible because # it blocks on all previous kernels. device = self.runner.device - block_table = ( - self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) + if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"): + block_table = (self.runner.input_batch.block_table. 
+ get_device_tensor()[:num_reqs]) + else: + block_table = self.runner.input_batch.block_table[ + 0].get_device_tensor() + block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = ( + block_table[:num_reqs]) slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( device, non_blocking=True) input_positions = self.runner.positions_cpu[:num_actual_tokens].to( diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 2ee742610..91f81956e 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -114,6 +114,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): def __init__(self, vllm_config: VllmConfig, device: torch.device): self.vllm_config = vllm_config self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config self.lora_config = vllm_config.lora_config self.scheduler_config = vllm_config.scheduler_config self.speculative_config = vllm_config.speculative_config @@ -172,24 +173,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): raise NotImplementedError( "Non-Attention backend is not supported by V1 NPUModelRunner.") - self.attn_backend = get_attn_backend( - self.head_size, - self.dtype, - self.kv_cache_dtype, - self.block_size, - self.model_config.is_attention_free, - use_mla=self.model_config.use_mla, - ) - if self.attn_backend is None: - error_msg = ( - f"Error with get_att_backend: {self.head_size=}, " - f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, " - f"{self.model_config.is_attention_free=}, " - f"{self.model_config.use_mla=}") - logger.error(error_msg) - raise NotImplementedError( - "Non-Attention backend is not supported by V1 GPUModelRunner.") - self.attn_metadata_builder = self.attn_backend.get_builder_cls()( weakref.proxy(self)) @@ -237,16 +220,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): pin_memory=True, vocab_size=self.model_config.get_vocab_size(), ) - else: - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.model_config.max_model_len, - max_num_blocks_per_req=self.max_num_blocks_per_req, - max_num_batched_tokens=self.max_num_tokens, - device=self.device, - pin_memory=True, - vocab_size=self.model_config.get_vocab_size(), - ) self.input_ids = torch.zeros(self.max_num_tokens, dtype=torch.int32, @@ -600,7 +573,10 @@ def _process_reqs( block_table_indices = (req_indices * self.max_num_blocks_per_req + positions_np // self.block_size) - block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"): + block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + else: + block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() block_offsets = positions_np % self.block_size np.add(block_numbers * self.block_size, @@ -1206,6 +1182,16 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ import torch_npu kv_caches: Dict[str, torch.Tensor] = {} + if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")): + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.model_config.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=True, + vocab_size=self.model_config.get_vocab_size(), + block_size=self.cache_config.block_size, + ) for kv_cache_group in kv_cache_config.kv_cache_groups: kv_cache_spec = 
kv_cache_group.kv_cache_spec From 01e3d59eae28be39df772dfb6c0113bf9e1552e9 Mon Sep 17 00:00:00 2001 From: Shuqiao Li Date: Mon, 26 May 2025 14:18:26 +0800 Subject: [PATCH 3/7] add workflow to build and release wheel (#775) ### What this PR does / why we need it? This is a continuing work of #716. This PR add workflow to build and release wheel, and also release source to PYPI. We have 3 conditions to trigger the workflow: 1. PR to `main` and `*-dev` 2. push to `main` and `*-dev` 3. push tag with name of `v*` Release to PYPI will only be done under condition 3. Under condition 1 and 2, it will generate .tar.gz and build .whl, upload to github artifacts but will not release. update: Will build .whl and upload to github artifacts with scheduled task. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? All triggered conditions are well tested with my fork repo. --------- Signed-off-by: Shuqiao Li Signed-off-by: Yikun Jiang Co-authored-by: Yikun Jiang --- .github/Dockerfile.buildwheel | 48 +++++++++++++++ .github/actionlint.yaml | 1 + .github/workflows/actionlint.yml | 2 +- .github/workflows/release_code.yml | 87 +++++++++++++++++++++++++++ .github/workflows/release_whl.yml | 95 ++++++++++++++++++++++++++++++ tools/actionlint.sh | 2 +- 6 files changed, 233 insertions(+), 2 deletions(-) create mode 100644 .github/Dockerfile.buildwheel create mode 100644 .github/workflows/release_code.yml create mode 100644 .github/workflows/release_whl.yml diff --git a/.github/Dockerfile.buildwheel b/.github/Dockerfile.buildwheel new file mode 100644 index 000000000..dfe8a63f6 --- /dev/null +++ b/.github/Dockerfile.buildwheel @@ -0,0 +1,48 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +ARG PY_VERSION=3.10 +FROM quay.io/ascend/cann:8.0.0-910b-ubuntu22.04-py${PY_VERSION} + +ARG COMPILE_CUSTOM_KERNELS=1 + +# Define environments +ENV DEBIAN_FRONTEND=noninteractive +ENV COMPILE_CUSTOM_KERNELS=${COMPILE_CUSTOM_KERNELS} +RUN apt-get update -y && \ + apt-get install -y python3-pip git vim wget net-tools gcc g++ cmake libnuma-dev && \ + rm -rf /var/cache/apt/* && \ + rm -rf /var/lib/apt/lists/* + +WORKDIR /workspace + +COPY . 
/workspace/vllm-ascend/ + +# Install req +RUN python3 -m pip install -r vllm-ascend/requirements.txt --extra-index https://download.pytorch.org/whl/cpu/ && \ + python3 -m pip install twine + +# Install vllm-ascend +RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh && \ + source /usr/local/Ascend/nnal/atb/set_env.sh && \ + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \ + cd vllm-ascend && \ + python3 setup.py bdist_wheel && \ + ls -l dist && \ + for f in dist/*.whl; do mv "$f" "$(echo "$f" | sed -e 's/-linux_x86_64\.whl$/-manylinux1_x86_64.whl/' -e 's/-linux_aarch64\.whl$/-manylinux2014_aarch64.whl/')"; done && \ + ls -l dist + +CMD ["/bin/bash"] diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index 14c73e817..78ea6f3bd 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -5,3 +5,4 @@ self-hosted-runner: - linux-arm64-npu-2 - linux-arm64-npu-4 - linux-arm64-npu-static-8 + - ubuntu-24.04-arm diff --git a/.github/workflows/actionlint.yml b/.github/workflows/actionlint.yml index daeb9d01e..91cd9c412 100644 --- a/.github/workflows/actionlint.yml +++ b/.github/workflows/actionlint.yml @@ -47,7 +47,7 @@ jobs: - name: "Run actionlint" env: - SHELLCHECK_OPTS: --exclude=SC2046,SC2006 + SHELLCHECK_OPTS: --exclude=SC2046,SC2006,SC2086 run: | echo "::add-matcher::.github/workflows/matchers/actionlint.json" tools/actionlint.sh -color diff --git a/.github/workflows/release_code.yml b/.github/workflows/release_code.yml new file mode 100644 index 000000000..b340cc492 --- /dev/null +++ b/.github/workflows/release_code.yml @@ -0,0 +1,87 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# + +name: build / sdist + +on: + pull_request: + branches: + - 'main' + - '*-dev' + paths: + - '.github/workflows/release_code.yml' + - 'vllm_ascend/**' + - 'setup.py' + - 'pyproject.toml' + - 'requirements.txt' + - 'cmake/**' + - 'CMakeLists.txt' + - 'csrc/**' + push: + branches: + - 'main' + - '*-dev' + tags: + - 'v*' + paths: + - '.github/workflows/release_code.yml' + - 'vllm_ascend/**' + - 'setup.py' + - 'pyproject.toml' + - 'requirements.txt' + - 'cmake/**' + - 'CMakeLists.txt' + - 'csrc/**' + +jobs: + build: + name: release code + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10"] + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Print + run: | + lscpu + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python3 -m pip install twine setuptools_scm + + - name: Generate tar.gz + run: | + python3 setup.py sdist + ls dist + + - name: Archive tar.gz + uses: actions/upload-artifact@v4 + with: + name: vllm-ascend-src + path: dist/* + + - name: Release + if: startsWith(github.ref, 'refs/tags/') + run: | + python3 -m twine upload dist/* -u __token__ -p ${{ secrets.PYPI_TOKEN }} diff --git a/.github/workflows/release_whl.yml b/.github/workflows/release_whl.yml new file mode 100644 index 000000000..f66a01588 --- /dev/null +++ b/.github/workflows/release_whl.yml @@ -0,0 +1,95 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +name: build / wheel + +on: + pull_request: + branches: + - 'main' + - '*-dev' + paths: + - '.github/workflows/release_whl.yml' + - '.github/Dockerfile.buildwheel' + - 'vllm_ascend/**' + - 'setup.py' + - 'pyproject.toml' + - 'requirements.txt' + - 'cmake/**' + - 'CMakeLists.txt' + - 'csrc/**' + push: + branches: + - 'main' + - '*-dev' + tags: + - 'v*' + paths: + - '.github/workflows/release_whl.yml' + - '.github/Dockerfile.buildwheel' + - 'vllm_ascend/**' + - 'setup.py' + - 'pyproject.toml' + - 'requirements.txt' + - 'cmake/**' + - 'CMakeLists.txt' + - 'csrc/**' + +jobs: + build: + name: build and release wheel + strategy: + matrix: + os: [ubuntu-24.04, ubuntu-24.04-arm] + python-version: ['3.9', '3.10', '3.11'] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Print + run: | + lscpu + + - name: Build wheel + run: | + ls + docker build -f ./.github/Dockerfile.buildwheel \ + --build-arg PY_VERSION=${{ matrix.python-version }} \ + -t wheel:v1 . 
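+          # Copy the wheels built inside the image back out to ./dist on the host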
+ docker run --rm \ + -v $(pwd):/outpwd \ + wheel:v1 \ + bash -c "cp -r /workspace/vllm-ascend/dist /outpwd" + ls dist + + - name: Archive wheel + uses: actions/upload-artifact@v4 + with: + name: vllm-ascend-${{ matrix.os }}-py${{ matrix.python-version }}-wheel + path: dist/* + + - name: Set up Python ${{ matrix.python-version }} + if: startsWith(github.ref, 'refs/tags/') + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: ${{ matrix.python-version }} + + - name: Release + if: startsWith(github.ref, 'refs/tags/') + run: | + python3 -m pip install twine + python3 -m twine upload --verbose dist/* -u __token__ -p ${{ secrets.PYPI_TOKEN }} diff --git a/tools/actionlint.sh b/tools/actionlint.sh index a050b568b..d1950db5c 100755 --- a/tools/actionlint.sh +++ b/tools/actionlint.sh @@ -18,7 +18,7 @@ # This file is a part of the vllm-ascend project. # Adapted from https://github.com/vllm-project/vllm/tree/main/tools # -export SHELLCHECK_OPTS="--exclude=SC2046,SC2006" +export SHELLCHECK_OPTS="--exclude=SC2046,SC2006,SC2086" if command -v actionlint &> /dev/null; then actionlint .github/workflows/*.yml .github/workflows/*.yaml From 9f5ab59e307a66fd0b17916218c9387c328b1c59 Mon Sep 17 00:00:00 2001 From: Angazenn <92204292+Angazenn@users.noreply.github.com> Date: Tue, 27 May 2025 15:16:17 +0800 Subject: [PATCH 4/7] [WIP][BugFix]Fix accuracy issues caused by wrong etp_size passed into FusedMoEParallelConfig when using vLLM 0.9.0 (#961) ### What this PR does / why we need it? This PR fix accuracy issues incurred by codes that adapt to `FusedMoEParallelConfig` in vLLM 0.9.0 version. The `tp_size` used to split weights are wrongly passed. The root cause is that vLLM community and vLLM-Ascend are using different methods to decide whether to use Expert Parallel. vLLM: vLLM use a flag `enable_expert_parallel` to indicate whether to use EP and use the following codes to decide `ep_size`: ``` use_ep = (dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel) dp_size = dp_size_ dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 tp_size, tp_rank = flatten_tp_across_dp(dp_rank) if not use_ep: return FusedMoEParallelConfig(tp_size=tp_size, tp_rank=tp_rank, dp_size=dp_size, dp_rank=dp_rank, ep_size=1, ep_rank=0, use_ep=False) # DP + EP / TP + EP / DP + TP + EP assert use_ep # In EP, each device owns a set of experts fully. There is no tensor # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. ep_size = tp_size ep_rank = tp_rank return FusedMoEParallelConfig(tp_size=1, tp_rank=0, dp_size=dp_size, dp_rank=dp_rank, ep_size=ep_size, ep_rank=ep_rank, use_ep=True) ``` vLLM-Ascend: vLLM-Ascend uses `etp` to specify Tensor Parallel in MoE. ``` self.ep_size = get_ep_group().world_size self.tp_size = get_etp_group().world_size self.dp_size = (dp_size if dp_size is not None else get_dp_group().world_size) ``` So there will be conflicts if we simply combine these codes together. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? 
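For illustration only (not part of this patch): a minimal, self-contained sketch of the intended override, where `MoEParallelConfig` is a hypothetical stand-in for vLLM's `FusedMoEParallelConfig` and the group world sizes are passed in explicitly instead of coming from `get_ep_group()`/`get_etp_group()`:
```
from dataclasses import dataclass

@dataclass
class MoEParallelConfig:
    # Hypothetical stand-in for the FusedMoEParallelConfig fields used here.
    tp_size: int = 1
    ep_size: int = 1

def apply_ascend_moe_parallel_override(cfg: MoEParallelConfig,
                                       ep_world_size: int,
                                       etp_world_size: int) -> MoEParallelConfig:
    # vLLM-Ascend derives EP from its own process group rather than vLLM's
    # enable_expert_parallel flag, and uses a separate ETP group for tensor
    # parallelism inside MoE, so both fields are overridden after construction.
    cfg.ep_size = ep_world_size
    cfg.tp_size = etp_world_size  # the missing override this patch adds
    return cfg

# Example: 8 devices split as ep=4, etp=2
print(apply_ascend_moe_parallel_override(MoEParallelConfig(), 4, 2))
```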
Signed-off-by: angazenn Co-authored-by: angazenn --- vllm_ascend/ops/fused_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 6313a7506..bc3b86b65 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -748,6 +748,7 @@ def __init__( vllm_parallel_config=vllm_config.parallel_config)) self.moe_parallel_config.ep_size = get_ep_group().world_size + self.moe_parallel_config.tp_size = get_etp_group().world_size self.top_k = top_k self.num_experts = num_experts From c9fa218423ecc9e72cabbae6b51f8140e0eda9d2 Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Tue, 27 May 2025 23:05:21 +0800 Subject: [PATCH 5/7] add VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER feature --- vllm_ascend/envs.py | 2 + vllm_ascend/ops/fused_moe.py | 1 + vllm_ascend/quantization/w8a8_dynamic.py | 243 ++++++++++++++++------- 3 files changed, 176 insertions(+), 70 deletions(-) diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 8e1cc1c16..39d81f4ee 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -36,6 +36,8 @@ lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))), "VLLM_ENABLE_MC2": lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))), + "VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER": + lambda: bool(int(os.getenv("VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER", '0'))), "USING_LCCL_COM": lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))), "SOC_VERSION": diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 01da5bef3..eb64f31bc 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -614,6 +614,7 @@ def __init__( self.e_score_correction_bias = e_score_correction_bias self.expert_map = None self.activation = activation + self.max_model_len = vllm_config.model_config.max_model_len if self.ep_size > 1: # Create a tensor of size num_experts filled with -1 diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index b1eea5890..8eaf73240 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -27,6 +27,7 @@ from vllm_ascend.ops.fused_moe import select_experts VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 +VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER: bool = envs_ascend.VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, @@ -58,16 +59,13 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, experts_per_ep_rank_val).to(original_dtype) indices_arange = torch.arange(topk_ids.shape[0], device=device) - is_new_segment = torch.cat( - (torch.tensor([True], device=device), assigned_ep_rank[1:] - != assigned_ep_rank[:-1])) + is_new_segment = torch.cat((torch.tensor([True], device=device), + assigned_ep_rank[1:] != assigned_ep_rank[:-1])) temp_start_markers = torch.full_like(indices_arange, -1, dtype=indices_arange.dtype) temp_start_markers[is_new_segment] = indices_arange[is_new_segment] - start_offset_for_each_token = torch.cummax(temp_start_markers.float(), - dim=0)[0].to( - temp_start_markers.dtype) + start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0] token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long) @@ -303,48 +301,28 @@ def fused_experts_with_all2all( expert_idx=topk_ids, active_num=num_tokens) - local_buffer_rows = (num_tokens // ep_group.world_size + 
- 1) * ep_group.world_size * top_k * 2 - max_row_per_ep_rank = local_buffer_rows // ep_group.world_size - expert_idx_buffer_scatter, unpad_indices = process_topk_ids( - expanded_expert_idx, global_num_experts, ep_group.world_size, - max_row_per_ep_rank, num_tokens, top_k) - hidden_states_pad_idx = torch.zeros( - expert_idx_buffer_scatter.shape, - dtype=expert_idx_buffer_scatter.dtype, - device=expert_idx_buffer_scatter.device) - non_pad_len = torch.sum((expert_idx_buffer_scatter - != global_num_experts).to(torch.int32)) - hidden_states_pad_idx[expert_idx_buffer_scatter != - global_num_experts] = torch.arange( - non_pad_len, - dtype=expert_idx_buffer_scatter.dtype, - device=hidden_states.device) - expert_idx_buffer_gather = torch.empty_like( - expert_idx_buffer_scatter, - dtype=expert_idx_buffer_scatter.dtype, - device=expert_idx_buffer_scatter.device) - dist.all_to_all_single(expert_idx_buffer_gather, - expert_idx_buffer_scatter, + global_expert_tokens = torch.bincount(expanded_expert_idx, + minlength=global_num_experts) + scatter_sizes = global_expert_tokens.view(ep_group.world_size, + -1).sum(-1) + + gather_sizes = torch.empty_like(scatter_sizes) + dist.all_to_all_single(gather_sizes, + scatter_sizes, group=ep_group.device_group) - hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx] + scatter_size_list = scatter_sizes.cpu().tolist() + gather_size_list = gather_sizes.cpu().tolist() - hidden_states_buffer_gather = torch.empty_like( - hidden_states_buffer_scatter, - dtype=hidden_states_buffer_scatter.dtype, - device=hidden_states_buffer_scatter.device) + expanded_expert_idx = expanded_expert_idx % local_num_experts + hidden_states = ep_group.all_to_all(hidden_states, 0, 0, + scatter_size_list, + gather_size_list) + local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0, + scatter_size_list, + gather_size_list) + + sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx) - dist.all_to_all_single(hidden_states_buffer_gather, - hidden_states_buffer_scatter, - group=ep_group.device_group) - mask = expert_idx_buffer_gather != global_num_experts - local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( - global_num_experts // ep_group.world_size) - hidden_states = hidden_states_buffer_gather[mask] - idx_type = local_expert_idx.dtype - sorted_local_expert_idx, sorted_idx = torch.sort( - local_expert_idx.float()) - sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type) expert_tokens = torch_npu.npu_moe_compute_expert_tokens( sorted_local_expert_idx, local_num_experts).to(torch.int64) @@ -380,30 +358,12 @@ def fused_experts_with_all2all( group_list_type=group_list_type) if expert_map is not None: - idx_type = sorted_idx.dtype - resorted_idx = torch.argsort(sorted_idx.float()).to(idx_type) + resorted_idx = torch.argsort(sorted_idx) hidden_states = hidden_states[resorted_idx] - hidden_states_scatter = torch.zeros( - (mask.shape[0], hidden_states.shape[1]), - dtype=hidden_states.dtype, - device=hidden_states.device) - hidden_states_scatter[mask] = hidden_states - hidden_states_gatter = torch.empty_like( - hidden_states_scatter, - dtype=hidden_states_scatter.dtype, - device=hidden_states_scatter.device) - dist.all_to_all_single(hidden_states_gatter, - hidden_states_scatter, - group=ep_group.device_group) - hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter - != global_num_experts] - if hidden_states_gatter.shape[0] != row_idx_len: - hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]), - dtype=hidden_states.dtype, 
- device=hidden_states.device) - hidden_states[unpad_indices != -1] = hidden_states_gatter - else: - hidden_states = hidden_states_gatter + hidden_states = ep_group.all_to_all(hidden_states, 0, 0, + gather_size_list, + scatter_size_list) + final_hidden_states = torch_npu.npu_moe_finalize_routing( hidden_states, skip1=None, @@ -430,6 +390,133 @@ def fused_experts_with_all2all( return final_hidden_states +def fused_experts_with_all2all_with_fixed_buffer( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + max_model_len: int, + expert_map: torch.Tensor = None, + ep_group: GroupCoordinator = None, +): + original_shape = hidden_states.shape + if len(original_shape) == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + num_tokens, _ = hidden_states.shape + device = hidden_states.device + + global_num_experts = len(expert_map) + local_num_experts = global_num_experts // ep_group.world_size + row_idx_len = num_tokens * top_k + row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32, + device=device).view(top_k, + -1).permute(1, 0).contiguous()) + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens) + + local_buffer_rows = (max_model_len // ep_group.world_size + + 1) * ep_group.world_size * top_k * 2 + max_row_per_ep_rank = local_buffer_rows // ep_group.world_size + expert_idx_buffer_scatter, unpad_indices = process_topk_ids( + expanded_expert_idx, global_num_experts, ep_group.world_size, + max_row_per_ep_rank, num_tokens, top_k) + hidden_states_pad_idx = torch.zeros( + expert_idx_buffer_scatter.shape, + dtype=expert_idx_buffer_scatter.dtype, + device=expert_idx_buffer_scatter.device) + non_pad_len = torch.sum( + (expert_idx_buffer_scatter != global_num_experts).to(torch.int32)) + hidden_states_pad_idx[ + expert_idx_buffer_scatter != global_num_experts] = torch.arange( + non_pad_len, + dtype=expert_idx_buffer_scatter.dtype, + device=hidden_states.device) + + hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx] + expert_idx_buffer_gather = torch.empty_like( + expert_idx_buffer_scatter, + dtype=expert_idx_buffer_scatter.dtype, + device=expert_idx_buffer_scatter.device) + hidden_states_buffer_gather = torch.empty_like( + hidden_states_buffer_scatter, + dtype=hidden_states_buffer_scatter.dtype, + device=hidden_states_buffer_scatter.device) + dist.all_to_all_single(expert_idx_buffer_gather, + expert_idx_buffer_scatter, + group=ep_group.device_group) + dist.all_to_all_single(hidden_states_buffer_gather, + hidden_states_buffer_scatter, + group=ep_group.device_group) + mask = expert_idx_buffer_gather != global_num_experts + local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( + global_num_experts // ep_group.world_size) + hidden_states = hidden_states_buffer_gather[mask] + idx_type = local_expert_idx.dtype + sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float()) + sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + sorted_local_expert_idx, local_num_experts).to(torch.int64) + hidden_states = hidden_states[sorted_idx] + group_list_type = 0 + + hidden_states_wrapper = [hidden_states] + del hidden_states + + hidden_states = apply_mlp(hidden_states_wrapper, + w1, + w1_scale, + w2, + w2_scale, + 
expert_tokens, + group_list_type=group_list_type) + + resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype) + hidden_states = hidden_states[resorted_idx] + hidden_states_scatter = torch.zeros( + (mask.shape[0], hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device) + hidden_states_scatter[mask] = hidden_states + hidden_states_gatter = torch.empty_like( + hidden_states_scatter, + dtype=hidden_states_scatter.dtype, + device=hidden_states_scatter.device) + dist.all_to_all_single(hidden_states_gatter, + hidden_states_scatter, + group=ep_group.device_group) + hidden_states_gatter = hidden_states_gatter[ + expert_idx_buffer_scatter != global_num_experts] + if hidden_states_gatter.shape[0] != row_idx_len: + hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device) + hidden_states[unpad_indices != -1] = hidden_states_gatter + else: + hidden_states = hidden_states_gatter + final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + + if len(original_shape) == 3: + final_hidden_states = final_hidden_states.view(original_shape) + return final_hidden_states + + def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w1_scale: torch.Tensor, @@ -687,7 +774,6 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = True, enable_force_load_balance: bool = True, - dp_size: int = 1, **kwargs, ) -> torch.Tensor: assert router_logits.shape[ @@ -740,7 +826,7 @@ def apply( top_k=top_k, expert_map=expert_map, moe_all_to_all_group_name=self.moe_all_to_all_group_name) - elif dp_size == 1: + elif self.ep_group.world_size == 1: return fused_experts(hidden_states=x, w1=layer.w13_weight, w1_scale=layer.w13_weight_scale, @@ -750,7 +836,24 @@ def apply( topk_ids=topk_ids, top_k=top_k, expert_map=expert_map) + elif VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER and expert_map is not None: + return fused_experts_with_all2all_with_fixed_buffer( + hidden_states=x, + w1=layer.w13_weight, + w1_scale=layer.w13_weight_scale, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + max_model_len=layer.max_model_len, + expert_map=expert_map, + ep_group=self.ep_group) else: + # The current implementation of deepseek moe splits hidden_states + # according to tp_size before they are feed into fused_moe module. + # Therefore, all2all is needed no matter how dp/tp is set so as to + # dispatch/combine tokens. 
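+            # This default path exchanges per-rank scatter/gather sizes at
+            # runtime, so the all2all buffers are re-sized on every call.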
return fused_experts_with_all2all(hidden_states=x, w1=layer.w13_weight, w1_scale=layer.w13_weight_scale, From 8f3ec63eb116cdb80b2db9e5c6d5f92ea5e5b093 Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Sun, 25 May 2025 18:42:25 +0800 Subject: [PATCH 6/7] optimize expert parallelism implemented with all2all Signed-off-by: SlightwindSec --- vllm_ascend/quantization/w8a8_dynamic.py | 153 +++++++++++++++++++---- 1 file changed, 129 insertions(+), 24 deletions(-) diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 0f54b012f..b262d3caa 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -29,6 +29,73 @@ VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 +def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, + max_row_per_ep_rank: int, num_tokens: int, + top_k: int) -> tuple[torch.Tensor, torch.Tensor]: + original_total_elements = num_tokens * top_k + device = topk_ids.device + original_dtype = topk_ids.dtype + + if original_total_elements == 0: + output_len = ep_size * max_row_per_ep_rank + topk_ids_pad = torch.full((output_len, ), + expert_num, + dtype=original_dtype, + device=device) + unpad_indices = torch.full((original_total_elements, ), + -1, + dtype=torch.long, + device=device) + return topk_ids_pad, unpad_indices + + experts_per_ep_rank_val = expert_num // ep_size + if experts_per_ep_rank_val == 0: + raise ValueError( + "expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. " + "Ensure expert_num >= ep_size.") + + assigned_ep_rank = (topk_ids.float() / + experts_per_ep_rank_val).to(original_dtype) + indices_arange = torch.arange(topk_ids.shape[0], device=device) + + is_new_segment = torch.cat( + (torch.tensor([True], device=device), assigned_ep_rank[1:] + != assigned_ep_rank[:-1])) + temp_start_markers = torch.full_like(indices_arange, + -1, + dtype=indices_arange.dtype) + temp_start_markers[is_new_segment] = indices_arange[is_new_segment] + start_offset_for_each_token = torch.cummax(temp_start_markers.float(), + dim=0)[0].to( + temp_start_markers.dtype) + token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token + is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank + cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long) + indices_in_rec_cond_list_for_all = cumsum_kept - 1 + unpad_indices = torch.where( + is_kept_mask, indices_in_rec_cond_list_for_all, + torch.tensor(-1, device=device, dtype=torch.long)) + output_len = ep_size * max_row_per_ep_rank + topk_ids_pad = torch.full((output_len, ), + expert_num, + dtype=original_dtype, + device=device) + if topk_ids.shape[0] > 0: + all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx + temp_pad_buffer = torch.full((output_len + 1, ), + expert_num, + dtype=original_dtype, + device=device) + output_len_tensor = torch.tensor(output_len, + dtype=torch.long, + device=device) + scatter_indices = torch.where(is_kept_mask, all_destination_indices, + output_len_tensor) + temp_pad_buffer.scatter_(0, scatter_indices, topk_ids) + topk_ids_pad = temp_pad_buffer[:output_len] + return topk_ids_pad, unpad_indices + + def apply_mlp(hidden_states_wrapper: List[torch.Tensor], w1: torch.Tensor, w1_scale: torch.Tensor, @@ -236,28 +303,48 @@ def fused_experts_with_all2all( expert_idx=topk_ids, active_num=num_tokens) - global_expert_tokens = torch.bincount(expanded_expert_idx, - minlength=global_num_experts) - scatter_sizes = 
global_expert_tokens.view(ep_group.world_size, - -1).sum(-1) - - gather_sizes = torch.empty_like(scatter_sizes) - dist.all_to_all_single(gather_sizes, - scatter_sizes, + local_buffer_rows = (num_tokens // ep_group.world_size + + 1) * ep_group.world_size * top_k * 2 + max_row_per_ep_rank = local_buffer_rows // ep_group.world_size + expert_idx_buffer_scatter, unpad_indices = process_topk_ids( + expanded_expert_idx, global_num_experts, ep_group.world_size, + max_row_per_ep_rank, num_tokens, top_k) + hidden_states_pad_idx = torch.zeros( + expert_idx_buffer_scatter.shape, + dtype=expert_idx_buffer_scatter.dtype, + device=expert_idx_buffer_scatter.device) + non_pad_len = torch.sum((expert_idx_buffer_scatter + != global_num_experts).to(torch.int32)) + hidden_states_pad_idx[expert_idx_buffer_scatter != + global_num_experts] = torch.arange( + non_pad_len, + dtype=expert_idx_buffer_scatter.dtype, + device=hidden_states.device) + expert_idx_buffer_gather = torch.empty_like( + expert_idx_buffer_scatter, + dtype=expert_idx_buffer_scatter.dtype, + device=expert_idx_buffer_scatter.device) + dist.all_to_all_single(expert_idx_buffer_gather, + expert_idx_buffer_scatter, group=ep_group.device_group) - scatter_size_list = scatter_sizes.cpu().tolist() - gather_size_list = gather_sizes.cpu().tolist() - - expanded_expert_idx = expanded_expert_idx % local_num_experts - hidden_states = ep_group.all_to_all(hidden_states, 0, 0, - scatter_size_list, - gather_size_list) - local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0, - scatter_size_list, - gather_size_list) + hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx] - sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx) + hidden_states_buffer_gather = torch.empty_like( + hidden_states_buffer_scatter, + dtype=hidden_states_buffer_scatter.dtype, + device=hidden_states_buffer_scatter.device) + dist.all_to_all_single(hidden_states_buffer_gather, + hidden_states_buffer_scatter, + group=ep_group.device_group) + mask = expert_idx_buffer_gather != global_num_experts + local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( + global_num_experts // ep_group.world_size) + hidden_states = hidden_states_buffer_gather[mask] + idx_type = local_expert_idx.dtype + sorted_local_expert_idx, sorted_idx = torch.sort( + local_expert_idx.float()) + sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type) expert_tokens = torch_npu.npu_moe_compute_expert_tokens( sorted_local_expert_idx, local_num_experts).to(torch.int64) @@ -293,12 +380,30 @@ def fused_experts_with_all2all( group_list_type=group_list_type) if expert_map is not None: - resorted_idx = torch.argsort(sorted_idx) + idx_type = sorted_idx.dtype + resorted_idx = torch.argsort(sorted_idx.float()).to(idx_type) hidden_states = hidden_states[resorted_idx] - hidden_states = ep_group.all_to_all(hidden_states, 0, 0, - gather_size_list, - scatter_size_list) - + hidden_states_scatter = torch.zeros( + (mask.shape[0], hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device) + hidden_states_scatter[mask] = hidden_states + hidden_states_gatter = torch.empty_like( + hidden_states_scatter, + dtype=hidden_states_scatter.dtype, + device=hidden_states_scatter.device) + dist.all_to_all_single(hidden_states_gatter, + hidden_states_scatter, + group=ep_group.device_group) + hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter + != global_num_experts] + if hidden_states_gatter.shape[0] != row_idx_len: + hidden_states = torch.zeros((row_idx_len, 
hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device) + hidden_states[unpad_indices != -1] = hidden_states_gatter + else: + hidden_states = hidden_states_gatter final_hidden_states = torch_npu.npu_moe_finalize_routing( hidden_states, skip1=None, From b758c2723acd9228dd87afa324ffe97647abcb0c Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Tue, 27 May 2025 23:05:21 +0800 Subject: [PATCH 7/7] add VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER feature --- vllm_ascend/envs.py | 2 + vllm_ascend/ops/fused_moe.py | 1 + vllm_ascend/quantization/w8a8_dynamic.py | 236 ++++++++++++++++------- 3 files changed, 171 insertions(+), 68 deletions(-) diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 8e1cc1c16..39d81f4ee 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -36,6 +36,8 @@ lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))), "VLLM_ENABLE_MC2": lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))), + "VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER": + lambda: bool(int(os.getenv("VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER", '0'))), "USING_LCCL_COM": lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))), "SOC_VERSION": diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index bc3b86b65..f2214a9f6 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -767,6 +767,7 @@ def __init__( self.e_score_correction_bias = e_score_correction_bias self.expert_map = None self.activation = activation + self.max_model_len = vllm_config.model_config.max_model_len if self.ep_size > 1: # Create a tensor of size num_experts filled with -1 diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index b262d3caa..8eaf73240 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -27,6 +27,7 @@ from vllm_ascend.ops.fused_moe import select_experts VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 +VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER: bool = envs_ascend.VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, @@ -58,16 +59,13 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, experts_per_ep_rank_val).to(original_dtype) indices_arange = torch.arange(topk_ids.shape[0], device=device) - is_new_segment = torch.cat( - (torch.tensor([True], device=device), assigned_ep_rank[1:] - != assigned_ep_rank[:-1])) + is_new_segment = torch.cat((torch.tensor([True], device=device), + assigned_ep_rank[1:] != assigned_ep_rank[:-1])) temp_start_markers = torch.full_like(indices_arange, -1, dtype=indices_arange.dtype) temp_start_markers[is_new_segment] = indices_arange[is_new_segment] - start_offset_for_each_token = torch.cummax(temp_start_markers.float(), - dim=0)[0].to( - temp_start_markers.dtype) + start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0] token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long) @@ -303,48 +301,28 @@ def fused_experts_with_all2all( expert_idx=topk_ids, active_num=num_tokens) - local_buffer_rows = (num_tokens // ep_group.world_size + - 1) * ep_group.world_size * top_k * 2 - max_row_per_ep_rank = local_buffer_rows // ep_group.world_size - expert_idx_buffer_scatter, unpad_indices = process_topk_ids( - expanded_expert_idx, global_num_experts, ep_group.world_size, - max_row_per_ep_rank, num_tokens, top_k) - 
hidden_states_pad_idx = torch.zeros( - expert_idx_buffer_scatter.shape, - dtype=expert_idx_buffer_scatter.dtype, - device=expert_idx_buffer_scatter.device) - non_pad_len = torch.sum((expert_idx_buffer_scatter - != global_num_experts).to(torch.int32)) - hidden_states_pad_idx[expert_idx_buffer_scatter != - global_num_experts] = torch.arange( - non_pad_len, - dtype=expert_idx_buffer_scatter.dtype, - device=hidden_states.device) - expert_idx_buffer_gather = torch.empty_like( - expert_idx_buffer_scatter, - dtype=expert_idx_buffer_scatter.dtype, - device=expert_idx_buffer_scatter.device) - dist.all_to_all_single(expert_idx_buffer_gather, - expert_idx_buffer_scatter, + global_expert_tokens = torch.bincount(expanded_expert_idx, + minlength=global_num_experts) + scatter_sizes = global_expert_tokens.view(ep_group.world_size, + -1).sum(-1) + + gather_sizes = torch.empty_like(scatter_sizes) + dist.all_to_all_single(gather_sizes, + scatter_sizes, group=ep_group.device_group) - hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx] + scatter_size_list = scatter_sizes.cpu().tolist() + gather_size_list = gather_sizes.cpu().tolist() - hidden_states_buffer_gather = torch.empty_like( - hidden_states_buffer_scatter, - dtype=hidden_states_buffer_scatter.dtype, - device=hidden_states_buffer_scatter.device) + expanded_expert_idx = expanded_expert_idx % local_num_experts + hidden_states = ep_group.all_to_all(hidden_states, 0, 0, + scatter_size_list, + gather_size_list) + local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0, + scatter_size_list, + gather_size_list) + + sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx) - dist.all_to_all_single(hidden_states_buffer_gather, - hidden_states_buffer_scatter, - group=ep_group.device_group) - mask = expert_idx_buffer_gather != global_num_experts - local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( - global_num_experts // ep_group.world_size) - hidden_states = hidden_states_buffer_gather[mask] - idx_type = local_expert_idx.dtype - sorted_local_expert_idx, sorted_idx = torch.sort( - local_expert_idx.float()) - sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type) expert_tokens = torch_npu.npu_moe_compute_expert_tokens( sorted_local_expert_idx, local_num_experts).to(torch.int64) @@ -380,30 +358,12 @@ def fused_experts_with_all2all( group_list_type=group_list_type) if expert_map is not None: - idx_type = sorted_idx.dtype - resorted_idx = torch.argsort(sorted_idx.float()).to(idx_type) + resorted_idx = torch.argsort(sorted_idx) hidden_states = hidden_states[resorted_idx] - hidden_states_scatter = torch.zeros( - (mask.shape[0], hidden_states.shape[1]), - dtype=hidden_states.dtype, - device=hidden_states.device) - hidden_states_scatter[mask] = hidden_states - hidden_states_gatter = torch.empty_like( - hidden_states_scatter, - dtype=hidden_states_scatter.dtype, - device=hidden_states_scatter.device) - dist.all_to_all_single(hidden_states_gatter, - hidden_states_scatter, - group=ep_group.device_group) - hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter - != global_num_experts] - if hidden_states_gatter.shape[0] != row_idx_len: - hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]), - dtype=hidden_states.dtype, - device=hidden_states.device) - hidden_states[unpad_indices != -1] = hidden_states_gatter - else: - hidden_states = hidden_states_gatter + hidden_states = ep_group.all_to_all(hidden_states, 0, 0, + gather_size_list, + scatter_size_list) + final_hidden_states = 
torch_npu.npu_moe_finalize_routing( hidden_states, skip1=None, @@ -430,6 +390,133 @@ def fused_experts_with_all2all( return final_hidden_states +def fused_experts_with_all2all_with_fixed_buffer( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + max_model_len: int, + expert_map: torch.Tensor = None, + ep_group: GroupCoordinator = None, +): + original_shape = hidden_states.shape + if len(original_shape) == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + num_tokens, _ = hidden_states.shape + device = hidden_states.device + + global_num_experts = len(expert_map) + local_num_experts = global_num_experts // ep_group.world_size + row_idx_len = num_tokens * top_k + row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32, + device=device).view(top_k, + -1).permute(1, 0).contiguous()) + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens) + + local_buffer_rows = (max_model_len // ep_group.world_size + + 1) * ep_group.world_size * top_k * 2 + max_row_per_ep_rank = local_buffer_rows // ep_group.world_size + expert_idx_buffer_scatter, unpad_indices = process_topk_ids( + expanded_expert_idx, global_num_experts, ep_group.world_size, + max_row_per_ep_rank, num_tokens, top_k) + hidden_states_pad_idx = torch.zeros( + expert_idx_buffer_scatter.shape, + dtype=expert_idx_buffer_scatter.dtype, + device=expert_idx_buffer_scatter.device) + non_pad_len = torch.sum( + (expert_idx_buffer_scatter != global_num_experts).to(torch.int32)) + hidden_states_pad_idx[ + expert_idx_buffer_scatter != global_num_experts] = torch.arange( + non_pad_len, + dtype=expert_idx_buffer_scatter.dtype, + device=hidden_states.device) + + hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx] + expert_idx_buffer_gather = torch.empty_like( + expert_idx_buffer_scatter, + dtype=expert_idx_buffer_scatter.dtype, + device=expert_idx_buffer_scatter.device) + hidden_states_buffer_gather = torch.empty_like( + hidden_states_buffer_scatter, + dtype=hidden_states_buffer_scatter.dtype, + device=hidden_states_buffer_scatter.device) + dist.all_to_all_single(expert_idx_buffer_gather, + expert_idx_buffer_scatter, + group=ep_group.device_group) + dist.all_to_all_single(hidden_states_buffer_gather, + hidden_states_buffer_scatter, + group=ep_group.device_group) + mask = expert_idx_buffer_gather != global_num_experts + local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( + global_num_experts // ep_group.world_size) + hidden_states = hidden_states_buffer_gather[mask] + idx_type = local_expert_idx.dtype + sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float()) + sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + sorted_local_expert_idx, local_num_experts).to(torch.int64) + hidden_states = hidden_states[sorted_idx] + group_list_type = 0 + + hidden_states_wrapper = [hidden_states] + del hidden_states + + hidden_states = apply_mlp(hidden_states_wrapper, + w1, + w1_scale, + w2, + w2_scale, + expert_tokens, + group_list_type=group_list_type) + + resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype) + hidden_states = hidden_states[resorted_idx] + hidden_states_scatter = torch.zeros( + (mask.shape[0], hidden_states.shape[1]), + 
dtype=hidden_states.dtype, + device=hidden_states.device) + hidden_states_scatter[mask] = hidden_states + hidden_states_gatter = torch.empty_like( + hidden_states_scatter, + dtype=hidden_states_scatter.dtype, + device=hidden_states_scatter.device) + dist.all_to_all_single(hidden_states_gatter, + hidden_states_scatter, + group=ep_group.device_group) + hidden_states_gatter = hidden_states_gatter[ + expert_idx_buffer_scatter != global_num_experts] + if hidden_states_gatter.shape[0] != row_idx_len: + hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device) + hidden_states[unpad_indices != -1] = hidden_states_gatter + else: + hidden_states = hidden_states_gatter + final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + + if len(original_shape) == 3: + final_hidden_states = final_hidden_states.view(original_shape) + return final_hidden_states + + def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w1_scale: torch.Tensor, @@ -749,6 +836,19 @@ def apply( topk_ids=topk_ids, top_k=top_k, expert_map=expert_map) + elif VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER and expert_map is not None: + return fused_experts_with_all2all_with_fixed_buffer( + hidden_states=x, + w1=layer.w13_weight, + w1_scale=layer.w13_weight_scale, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + max_model_len=layer.max_model_len, + expert_map=expert_map, + ep_group=self.ep_group) else: # The current implementation of deepseek moe splits hidden_states # according to tp_size before they are feed into fused_moe module.