
Reconcile merge differences [fix Custom All Reduce; remove Torchrun & Cython] #163

Closed

Conversation

@mawong-amd commented Sep 3, 2024

Apart from small miscellaneous fixes, this PR also contains the following major changes:

  1. Fixes and enables custom all reduce on ROCm.
  2. Removes the torchrun executor, as it does not provide performance benefits over the now-default multiprocessing executor.
  3. Removes the Cython-ization of sampler-related files, as the sampler has seen performance improvements upstream.


github-actions bot commented Sep 3, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which consists of a small and essential subset of CI tests to quickly catch errors. You can run the other CI tests on top of the default ones by unblocking the steps in your fastcheck build on the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mawong-amd (Author):

/ready

@mawong-amd changed the title from "Reconcile merge diffs" to "Reconcile merge differences [Custom All Reduce, Triton FA, remove torchrun & cython]" on Sep 3, 2024
@mawong-amd changed the title from "Reconcile merge differences [Custom All Reduce, Triton FA, remove torchrun & cython]" to "Reconcile merge differences [fix Custom All Reduce & Triton FA, remove Torchrun & Cython]" on Sep 3, 2024
@mawong-amd requested a review from gshtras on September 3, 2024, 11:52
@mawong-amd changed the title from "Reconcile merge differences [fix Custom All Reduce & Triton FA, remove Torchrun & Cython]" to "Reconcile merge differences [fix Custom All Reduce & Triton FA; remove Torchrun & Cython]" on Sep 3, 2024
@mawong-amd force-pushed the mawong/reconcile_merge branch from 922e143 to a61a72a on September 3, 2024, 12:37
@gshtras (Collaborator) left a comment:

Overall good additions and a few bugs caught, thanks
A couple of suggestions though, plus, let's move on to rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0 as the default base, and leave only building rccl (from the rocm-6.2.0 tag) and triton, as it is buggy there lol

Dockerfile.rocm Outdated
ENV CCACHE_DIR=/root/.cache/ccache

RUN python3 -m pip install --upgrade pip
# Remove sccache so it doesn't interfere with ccache
Collaborator:

What's the motivation here?

@mawong-amd (Author), Sep 3, 2024:

Upstream CI previously added ccache to try to cut down on vLLM compile time. It's not important to have it locally, but it might be marginally useful for internal CI. We should switch to sccache though, as it's better supported by ROCm components... but truthfully this entire thing is not important. It's changed here because that's what's in upstream.

Dockerfile.rocm Outdated
fi
# Package upgrades for useful functionality or to avoid dependency issues
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install --upgrade numba scipy huggingface-hub[cli]
Collaborator:

Why do we want to bundle huggingface-cli?

Author:

No huge reason, it was done previously so we can conveniently do huggingface-cli login to access gated models needed for some tests.

int M = in_a.size(0);
int K = in_a.size(1);
void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
const int64_t rows_per_block = 4) {
Collaborator:

Let's remove this default, it doesn't make sense.

int K = in_a.size(1);
void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
const int64_t N_in, const int64_t CuCount) {
auto M = in_a.size(0);
Collaborator:

Nice

Collaborator:

Upstream has a newer version of awq according to @rasmith

Author:

I looked at the upstream PR but it doesn't seem to include this PR: #146? Or is this later PR not necessary?

Reply:

The upstream PR came after this and should override what we currently have.

@@ -207,6 +182,12 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int,
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
# print(f"awq_dequantize:qweight.shape = {qweight.shape}"
Collaborator:

Same here, we want the clean implementation that was upstreamed

Collaborator:

We want to revert to the original, more performant implementation. The current version is a workaround for a now-fixed compiler bug and was never upstreamed, so taking the upstream version is the way to go here.

@@ -173,8 +174,9 @@ async def _check_model(request: Union[CompletionRequest,

async def _guided_decode_logits_processor(request, tokenizer):
decoding_config = runner.engine_config.decoding_config
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
assert decoding_config is not None
Collaborator:

Strictly speaking not required if it arrives in the request

@mawong-amd (Author), Sep 3, 2024:

Yeah, oops, the logic for the assert as written isn't right...

For context, this was added because the linter complained that decoding_config could be None (in which case it doesn't have the guided_decoding_backend attribute). A quick read suggests that's not actually possible because of the default initialization, but unfortunately the linter can't trace types that far.

So either we add a # type: ignore, or we perform this assert only when request.guided_decoding_backend is unset.

vllm/sequence.py (outdated; thread resolved)
@@ -57,7 +58,25 @@ def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,

scales = scales.repeat_interleave(group_size, dim=0)
zeros = zeros.repeat_interleave(group_size, dim=0)
return (iweights - zeros) * scales
return (iweights - zeros) * scales, zeros
Reply:

This is old, I thought @gshtras was bringing in the new stuff from upstream.

Author:

This is old, I thought @gshtras was bringing in the new stuff from upstream.

Yeah something's a bit strange here... Having a newer version of this line misled me into thinking PR#146 is newer than upstream. Will re-inspect the AWQ merge.

f" qweight_rows = {qweight_rows} qweight_cols = {qweight_cols}"
f" scales_rows = {scales_rows} scales_cols = {scales_cols}")
weights = torch_awq_dequantize(qweight, scales, qzeros)
return torch.matmul(input, weights)
Reply:

This is also old; there is newer stuff upstream. It's also not listed in points 1-4 of the description. Did you mean to change the AWQ stuff? If so, please bring in the new changes from upstream.

@@ -207,6 +182,12 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int,
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
# print(f"awq_dequantize:qweight.shape = {qweight.shape}"
# f"scales = {scales.shape},"
Reply:

old

@@ -217,6 +198,12 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,

def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
# if input.shape[0] > 1:
Reply:

old

vllm/envs.py Outdated
@@ -444,6 +444,10 @@ def get_default_config_root():
lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),

# If set, vLLM will use Triton implementations of AWQ.
Reply:

Again, this is old. Get the new stuff from upstream.

Reply:

The upstream PR came after this and should override what we currently have.

std::vector<uint8_t> get_meta_buffer_ipc_handle(torch::Tensor inp) {
std::vector<uint8_t> data_handle(sizeof(cudaIpcMemHandle_t), 0);
CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)data_handle.data(),
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp) {
Collaborator:

Could you elaborate on this? Why the change to a tensor, and why does its Python counterpart still mark the return type as List[str]?

@mawong-amd (Author), Sep 4, 2024:

using namespace std;

Why the change to torch::Tensor: torch_bindings.cpp doesn't compile if the return type is vector<uint8_t>, because the Torch extension system does not support binding this type to Python. The problem was hidden before this PR/cherry-pick because we did not compile the CAR extension and its associated bindings on ROCm.

Why List[str]: this was a problem carried over from the original PR. It looks like they first tried a vector<string> return instead of torch::Tensor and forgot to change the type hint afterwards. I assumed some implicit conversion was happening during the binding, but on second thought, clearly not. The type hints should be converted to torch.Tensor.

The status quo isn't strictly correct, but it follows upstream convention and works at present. As per the comments in the original PR, if this is wrong, we are likely to see evidence only once torch.compile is used more liberally.
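
As a rough sketch (illustrative only; the helper name and the omitted error handling are assumptions, not vLLM's actual get_meta_buffer_ipc_handle), the Tensor-returning approach amounts to copying the raw IPC handle bytes into a uint8 CPU tensor that the binding layer can hand to Python, unlike a bare std::vector<uint8_t>:

  #include <torch/torch.h>
  #include <cuda_runtime_api.h>
  #include <cstring>

  // Hypothetical helper: pack the CUDA/HIP IPC handle bytes into a uint8 CPU
  // tensor so the Torch extension binding layer can return it to Python.
  torch::Tensor ipc_handle_as_tensor(void* device_ptr) {
    cudaIpcMemHandle_t handle;
    // Real code should check the returned cudaError_t (e.g. via a CUDACHECK macro).
    cudaIpcGetMemHandle(&handle, device_ptr);

    auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
    auto out = torch::empty(
        {static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
    std::memcpy(out.data_ptr<uint8_t>(), &handle, sizeof(cudaIpcMemHandle_t));
    return out;
  }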

Extra context:
Some aspects of the Torch Extension system are not easy to understand and are not well-documented. I'm not convinced that the upstream PR gets everything right: for instance, it often does something like the following

  custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
  custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU, &get_graph_buffer_ipc_meta);

But it doesn't seem the second line is necessary: the original .def with the optional function pointer supplied already (implicitly) registers this method for all dispatch keys (including torch::kCUDA, torch::kCPU, etc.). The subsequent .impl registers this method with the CPU dispatch key and hence appears to be redundant (see below comment for why: the .def with function pointer actually bypasses the dispatch system, so registered dispatch keys aren't used).

In general we would expect .impl to be redundant if we are not intending on having different implementations of a method depending on the device (e.g. different CPU and GPU methods for computing sine of a Tensor). Put another way, we should never need to use .impl if we only ever have/expect a single implementation.

For CAR, which is currently NV/AMD GPU-only, the developer knows at compile-time whether a Tensor is on CPU or GPU, and accordingly we only have one implementation of each method. So we don't need the .impl-s. This might change in the future if say someone generalizes CAR to also work on CPU-only vLLM, and hence they need to provide implementations of these methods for the cases where some of the all reduce buffers are on CPU (instead of being on GPU).
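
As a rough sketch of that claim (names are illustrative, not vLLM's actual bindings), a def-only registration of a single GPU-only op would look roughly like this, with no .impl anywhere:

  #include <torch/library.h>
  #include <torch/torch.h>

  // Single implementation: the developer already knows where the tensors live,
  // so per-device dispatch adds nothing.
  torch::Tensor car_handle_example(const torch::Tensor& inp) {
    // Placeholder body; a real op would derive the IPC handle bytes from inp's buffer.
    return torch::empty({8}, inp.options().dtype(torch::kUInt8).device(torch::kCPU));
  }

  TORCH_LIBRARY(car_example, m) {
    // .def with a function pointer both infers/declares the schema and
    // registers a catch-all implementation, so the op is callable from
    // Python without any .impl call.
    m.def("car_handle_example", &car_handle_example);
  }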

@mawong-amd (Author), Sep 4, 2024:

Some documentation for details about custom operator dispatch: https://pytorch.org/cppdocs/api/file_torch_library.h.html#file-torch-library-h
https://pytorch.org/cppdocs/library.html#classtorch_1_1_library_1a59330802ab9deaff247d32a22092bdfb

  • How/when dispatching happens:

When you pass operators with tensors of your custom backend [editor note: corresponding to a particular dispatch key], your overridden implementations will be called instead of the standard implementations

Here is even more detailed information, good for understanding context but some mechanical details about the dispatch system are deprecated.

  • Using .def with a function pointer, with/without a schema:

Define an operator for a schema and then register an implementation for it.
This is typically what you would use if you aren’t planning on making use of the dispatcher to structure your operator implementation. It’s roughly equivalent to calling def() and then impl(), but if you omit the schema of the operator, we will infer it from the type of your C++ function. All template arguments are inferred.

  • Using .def without a function pointer, schema necessary:

Declare an operator with a schema, but don’t provide any implementations for it.
You’re expected to then provide implementations using the impl() method. All template arguments are inferred.

  • Using .impl:

Register an implementation for an operator.
You may register multiple implementations for a single operator at different dispatch keys (see torch::dispatch()). Implementations must have a corresponding declaration (from def()), otherwise they are invalid. If you plan to register multiple implementations, DO NOT provide a function implementation when you def() the operator.

@gshtras force-pushed the mawong/reconcile_merge branch from 646e40d to ee47dc3 on September 3, 2024, 22:15
@mawong-amd (Author) commented Sep 4, 2024:

Since 05e67ab has cherry-picked most of the salient changes in this PR, it will be closed and a new one opened for any further changes.

@mawong-amd (Author):

Overall good additions and a few bugs caught, thanks A couple of suggestions though, plus, let's move on to rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0 as the default base, and leave only building rccl (from the rocm-6.2.0 tag) and triton, as it is buggy there lol

Is there a particular reason we want rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0? rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_staging works pretty well.

@mawong-amd closed this on Sep 4, 2024
@mawong-amd force-pushed the mawong/reconcile_merge branch from 82a8e65 to 7fd46eb on September 4, 2024, 13:45
@mawong-amd changed the title from "Reconcile merge differences [fix Custom All Reduce & Triton FA; remove Torchrun & Cython]" to "Reconcile merge differences [fix Custom All Reduce; remove Torchrun & Cython]" on Sep 4, 2024