diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index ae88acf6b6501..d8ce0d8ed556a 100644
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -141,6 +141,10 @@ def forward(
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
+        assert query.stride(-1) == 1, "Query tensor must be contiguous."
+        assert key.stride(-1) == 1, "Key tensor must be contiguous."
+        assert value.stride(-1) == 1, "Value tensor must be contiguous."
+
         output = torch.empty_like(query)
         torch.ops.vllm.unified_v1_flash_attention(
             output,
@@ -210,23 +214,23 @@ def unified_v1_flash_attention(
         query[:num_actual_tokens],
         key_cache,
         value_cache,
-        None,
+        None,  # out
         attn_metadata.query_start_loc,
         attn_metadata.seq_start_loc,
-        None,
+        None,  # seqused_k
         attn_metadata.block_table,
         alibi_slopes,
         attn_metadata.max_query_len,
         attn_metadata.max_seq_len,
-        0.0,
+        0.0,  # dropout_p
         softmax_scale,
-        False,
-        True,
+        False,  # zero_tensors
+        True,  # causal
         window_size[0],
         window_size[1],
         logits_soft_cap,
-        False,
-        None,
+        False,  # return_softmax
+        None,  # generator
     )[0]
     # TODO(woosuk): Remove this unnecessary copy.
     output[:num_actual_tokens].copy_(attn_output)