Skip to content

Commit

Permalink
Merge branch 'develop' into navi_ck
Browse files (browse the repository at this point in the history)
  • Loading branch information
hyoon1 authored Dec 16, 2024
2 parents f087706 + 22f9066 commit 64ce953
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
4 changes: 2 additions & 2 deletions .buildkite/test-template.j2
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ steps:
- exit_status: -10 # Agent was lost
limit: 5
agents:
- queue: amd
+ queue: amd-cpu

{% for step in steps %}
{% if step.mirror_hardwares and "amd" in step.mirror_hardwares %}
- label: "AMD: {{ step.label }}"
depends_on:
- "amd-build"
agents:
- queue: amd
+ queue: amd_gpu
commands:
- bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" && ")) | safe }}"
env:
Expand Down
9 changes: 4 additions & 5 deletions vllm/attention/ops/triton_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,9 +912,8 @@ def check_and_convert(t, scale):
p_descale = 1.0 / p_scale
o_descale = 1.0 / o_scale

- if is_navi():
- max_seqlens_q = 0
- max_seqlens_k = 0
+ arg_max_seqlens_q = 0 if is_navi() else max_seqlens_q
+ arg_max_seqlens_k = 0 if is_navi() else max_seqlens_k

attn_fwd[grid](
q,
Expand Down Expand Up @@ -944,8 +943,8 @@ def check_and_convert(t, scale):
HQ=nheads_q,
HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size,
- MAX_SEQLENS_Q=max_seqlens_q,
- MAX_SEQLENS_K=max_seqlens_k,
+ MAX_SEQLENS_Q=arg_max_seqlens_q,
+ MAX_SEQLENS_K=arg_max_seqlens_k,
IS_CAUSAL=causal,
VARLEN=True,
BLOCK_DMODEL=padded_d_model,
Expand Down
2 changes: 2 additions & 0 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
elif vllm_config.speculative_config:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
+ parallel_config.sd_worker_cls = \
+ "vllm.worker.worker.Worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"

Expand Down

0 comments on commit 64ce953

Please sign in to comment.