Commit 7edb2fd

Update Dockerfile to 6.2, update ROCm components, remove Cython (#166)
* Miscellaneous changes, Dockerfile components update, remove Cython

* Restore Dockerfile and Cython for now
mawong-amd authored Sep 4, 2024
1 parent 7fd46eb commit 7edb2fd
Showing 4 changed files with 10 additions and 10 deletions.
csrc/custom/custom.cu (1 change: 0 additions & 1 deletion)
@@ -1,7 +1,6 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include "core/registration.h"

// declare templates for front (cpp) and back (cuda) sides of function:
// template <typename T>
vllm/_custom_ops.py (2 changes: 1 addition & 1 deletion)
@@ -646,7 +646,7 @@ def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
    return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets)


-def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
+def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
    return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
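The annotation fix matters to type-checked callers: the custom all-reduce op hands back its IPC metadata as a single tensor (presumably serialized handle bytes) plus per-rank offsets, not a list of strings. A minimal caller sketch under that assumption; the helper name and unpacking are illustrative, not part of this commit:

from typing import List, Tuple

import torch

from vllm import _custom_ops as ops


def fetch_graph_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
    # With the corrected annotation, type checkers see `handles` as a
    # tensor rather than a List[str]; `offsets` is unchanged.
    handles, offsets = ops.get_graph_buffer_ipc_meta(fa)
    return handles, offsets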


vllm/entrypoints/sync_openai/api_server.py (8 changes: 5 additions & 3 deletions)
@@ -174,9 +174,11 @@ async def _check_model(request: Union[CompletionRequest,

async def _guided_decode_logits_processor(request, tokenizer):
    decoding_config = runner.engine_config.decoding_config
-    assert decoding_config is not None
-    guided_decoding_backend = (request.guided_decoding_backend
-                               or decoding_config.guided_decoding_backend)
+    if request.guided_decoding_backend:
+        guided_decoding_backend = request.guided_decoding_backend
+    else:
+        assert decoding_config is not None
+        guided_decoding_backend = decoding_config.guided_decoding_backend
    return await get_guided_decoding_logits_processor(guided_decoding_backend,
                                                      request, tokenizer)
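The restructured branch changes when the assertion can fire: with the old or-expression, decoding_config had to be non-None even when the request named its own backend; now the engine config is consulted, and asserted non-None, only on the fallback path. A self-contained sketch of that difference, using stand-in dataclasses rather than vLLM's real request and config types:

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeRequest:
    guided_decoding_backend: Optional[str] = None


@dataclass
class FakeDecodingConfig:
    guided_decoding_backend: str = "outlines"


def pick_backend(request: FakeRequest,
                 decoding_config: Optional[FakeDecodingConfig]) -> str:
    # Mirrors the new control flow: the engine-level config is only
    # required when the request does not name a backend itself.
    if request.guided_decoding_backend:
        return request.guided_decoding_backend
    assert decoding_config is not None
    return decoding_config.guided_decoding_backend


# A request that names its backend no longer needs an engine config:
assert pick_backend(FakeRequest("lm-format-enforcer"), None) == "lm-format-enforcer"
assert pick_backend(FakeRequest(), FakeDecodingConfig()) == "outlines"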

vllm/platforms/rocm.py (9 changes: 4 additions & 5 deletions)
@@ -34,7 +34,7 @@
# the major benefit of using AMDSMI is that it will not initialize CUDA


-def with_nvml_context(fn):
+def with_amdsmi_context(fn):

    @wraps(fn)
    def wrapper(*args, **kwargs):
@@ -65,12 +65,11 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
        return torch.cuda.get_device_capability(device_id)

    @staticmethod
-    @with_nvml_context
+    @with_amdsmi_context
    def is_full_nvlink(physical_device_ids: List[int]) -> bool:
"""
query if the set of gpus are fully connected by xgmi (1 hop)
Query if the set of gpus are fully connected by xgmi (1 hop)
"""
# On ROCm, we instead query if GPUs are connected by 1 hop XGMI
handles = [
amdsmi_get_processor_handles()[i] for i in physical_device_ids
]
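The pairwise check that follows is collapsed in this view. A rough sketch of how a one-hop XGMI test can be written with the amdsmi bindings; the dict-style return of amdsmi_topo_get_link_type is an assumption here, not taken from this diff:

from itertools import combinations

from amdsmi import amdsmi_topo_get_link_type


def fully_connected_by_xgmi(handles) -> bool:
    # Assumed return shape: a dict with a "hops" entry; a pair counts
    # as directly connected when the link is one hop away.
    for a, b in combinations(handles, 2):
        link = amdsmi_topo_get_link_type(a, b)
        if link["hops"] != 1:
            return False
    return True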
@@ -90,7 +89,7 @@ def is_full_nvlink(physical_device_ids: List[int]) -> bool:
        return True

    @staticmethod
-    @with_nvml_context
+    @with_amdsmi_context
    @lru_cache(maxsize=8)
    def get_device_name(device_id: int = 0) -> str:
        physical_device_id = device_id_to_physical_device_id(device_id)
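Only the decorator's name appears in these hunks; its body is collapsed. A plausible sketch of what such a context decorator looks like, assuming the amdsmi Python bindings' amdsmi_init and amdsmi_shut_down entry points (the actual body is not shown in this diff):

from functools import wraps

from amdsmi import amdsmi_init, amdsmi_shut_down


def with_amdsmi_context(fn):

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # Bracket each call with AMDSMI setup and teardown; per the
        # comment above, AMDSMI is used because it does not initialize
        # CUDA.
        amdsmi_init()
        try:
            return fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()

    return wrapper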
