From cbcd07a09bb39632cbf2d263bf2a4119302e7eb2 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 10 May 2024 21:32:29 +0000 Subject: [PATCH] ruff --- vllm/_custom_ops.py | 320 ++++++++++++++++++++++++++++++-------------- 1 file changed, 223 insertions(+), 97 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index b25151f0fca6a..d7d4f35e33bf8 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -12,7 +12,6 @@ pass - # activation ops def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: vllm_ops.silu_and_mul(out, x) @@ -50,10 +49,21 @@ def paged_attention_v1( kv_cache_dtype: str, kv_scale: float, ) -> None: - vllm_ops.paged_attention_v1(out, query, key_cache, value_cache, - num_kv_heads, scale, block_tables, seq_lens, - block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, kv_scale) + vllm_ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + kv_scale, + ) def paged_attention_v2( @@ -74,11 +84,24 @@ def paged_attention_v2( kv_cache_dtype: str, kv_scale: float, ) -> None: - vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, - key_cache, value_cache, num_kv_heads, scale, - block_tables, seq_lens, block_size, - max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale) + vllm_ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + kv_scale, + ) # pos encoding ops @@ -90,85 +113,136 @@ def rotary_embedding( cos_sin_cache: torch.Tensor, is_neox: bool, ) -> None: - vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, - is_neox) + vllm_ops.rotary_embedding( + positions, query, key, head_size, cos_sin_cache, is_neox + ) -def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, head_size: int, - cos_sin_cache: torch.Tensor, is_neox: bool, - rot_dim: int, - cos_sin_cache_offsets: torch.Tensor) -> None: - vllm_ops.batched_rotary_embedding(positions, query, key, head_size, - cos_sin_cache, is_neox, rot_dim, - cos_sin_cache_offsets) +def batched_rotary_embedding( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor, +) -> None: + vllm_ops.batched_rotary_embedding( + positions, + query, + key, + head_size, + cos_sin_cache, + is_neox, + rot_dim, + cos_sin_cache_offsets, + ) # layer norm ops -def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - epsilon: float) -> None: +def rms_norm( + out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float +) -> None: vllm_ops.rms_norm(out, input, weight, epsilon) -def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor, epsilon: float) -> None: +def fused_add_rms_norm( + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + epsilon: float, +) -> None: vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon) # quantization ops # awq -def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, - zeros: torch.Tensor, split_k_iters: int, thx: int, - thy: int) -> torch.Tensor: - return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, - thy) - - -def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, 
qzeros: torch.Tensor, - scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: +def awq_dequantize( + qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + split_k_iters: int, + thx: int, + thy: int, +) -> torch.Tensor: + return vllm_ops.awq_dequantize( + qweight, scales, zeros, split_k_iters, thx, thy + ) + + +def awq_gemm( + input: torch.Tensor, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + split_k_iters: int, +) -> torch.Tensor: return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters) # gptq -def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, - b_g_idx: torch.Tensor, use_exllama: bool, - bit: int) -> torch.Tensor: - return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, - b_g_idx, use_exllama, bit) - - -def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, - bit: int) -> None: +def gptq_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, + b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, + use_exllama: bool, + bit: int, +) -> torch.Tensor: + return vllm_ops.gptq_gemm( + a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama, bit + ) + + +def gptq_shuffle( + q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int +) -> None: vllm_ops.gptq_shuffle(q_weight, q_perm, bit) # squeezellm -def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, - lookup_table: torch.Tensor) -> None: +def squeezellm_gemm( + vec: torch.Tensor, + mat: torch.Tensor, + mul: torch.Tensor, + lookup_table: torch.Tensor, +) -> None: vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table) # marlin -def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, - size_n: int, size_k: int) -> torch.Tensor: - return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, - size_n, size_k) +def marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return vllm_ops.marlin_gemm( + a, b_q_weight, b_scales, workspace, size_m, size_n, size_k + ) # cutlass -def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, - a_scales: torch.Tensor, - b_scales: torch.Tensor) -> torch.Tensor: - shape_fallback = (b.shape[0] % 16 != 0 or b.shape[1] % 16 != 0) +def cutlass_scaled_mm_dq( + a: torch.Tensor, + b: torch.Tensor, + a_scales: torch.Tensor, + b_scales: torch.Tensor, +) -> torch.Tensor: + shape_fallback = b.shape[0] % 16 != 0 or b.shape[1] % 16 != 0 if capability < 80 or shape_fallback: a_bf16 = a.to(dtype=torch.bfloat16) b_bf16 = b.to(dtype=torch.bfloat16) - return (b_scales * - (a_scales * torch.mm(a_bf16, b_bf16))).to(dtype=torch.bfloat16) + return (b_scales * (a_scales * torch.mm(a_bf16, b_bf16))).to( + dtype=torch.bfloat16 + ) else: m = a.shape[0] @@ -181,35 +255,66 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, # aqlm -def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, - codebooks: torch.Tensor, scales: torch.Tensor, - codebook_partition_sizes: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - return vllm_ops.aqlm_gemm(input, codes, codebooks, scales, - codebook_partition_sizes, bias) - - -def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, - codebook_partition_sizes: torch.Tensor) -> torch.Tensor: +def aqlm_gemm( + input: torch.Tensor, + codes: torch.Tensor, + codebooks: 
torch.Tensor, + scales: torch.Tensor, + codebook_partition_sizes: torch.Tensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + return vllm_ops.aqlm_gemm( + input, codes, codebooks, scales, codebook_partition_sizes, bias + ) + + +def aqlm_dequant( + codes: torch.Tensor, + codebooks: torch.Tensor, + codebook_partition_sizes: torch.Tensor, +) -> torch.Tensor: return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes) # gptq_marlin -def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, - size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: - return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, - num_bits) - - -def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, g_idx: torch.Tensor, - perm: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, size_k: int, - is_k_full: bool) -> torch.Tensor: - return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, - workspace, num_bits, size_m, size_n, - size_k, is_k_full) +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + return vllm_ops.gptq_marlin_repack( + b_q_weight, perm, size_k, size_n, num_bits + ) + + +def gptq_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, +) -> torch.Tensor: + return vllm_ops.gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + g_idx, + perm, + workspace, + num_bits, + size_m, + size_n, + size_k, + is_k_full, + ) # fp8 @@ -227,13 +332,22 @@ def scaled_fp8_quant( # moe -def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, - block_size: int, sorted_token_ids: torch.Tensor, - experts_ids: torch.Tensor, - num_tokens_post_pad: torch.Tensor) -> None: - vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size, - sorted_token_ids, experts_ids, - num_tokens_post_pad) +def moe_align_block_size( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + vllm_ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_token_ids, + experts_ids, + num_tokens_post_pad, + ) def reshape_and_cache( @@ -245,8 +359,15 @@ def reshape_and_cache( kv_cache_dtype: str, kv_scale: float, ) -> None: - vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, kv_scale) + vllm_cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + kv_scale, + ) def reshape_and_cache_flash( @@ -257,17 +378,22 @@ def reshape_and_cache_flash( slot_mapping: torch.Tensor, kv_cache_dtype: str, ) -> None: - vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype) + vllm_cache_ops.reshape_and_cache_flash( + key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype + ) -def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, - block_mapping: torch.Tensor) -> None: +def copy_blocks( + key_caches: torch.Tensor, + value_caches: torch.Tensor, + block_mapping: torch.Tensor, +) -> None: vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) -def swap_blocks(src: torch.Tensor, dst: torch.Tensor, - block_mapping: Dict[int, int]) -> None: +def 
swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: Dict[int, int] +) -> None: vllm_cache_ops.swap_blocks(src, dst, block_mapping) @@ -275,4 +401,4 @@ def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None: vllm_cache_ops.convert_fp8(output, input) -#TODO: cuda_utils, custom_ar +# TODO: cuda_utils, custom_ar
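
Note (addendum, not part of the patch): the hunks above are formatting-only; every wrapper keeps its original signature and argument order, so call sites are unaffected. As a rough sanity-check sketch, one of the reformatted wrappers can be exercised directly. This assumes a CUDA build of vllm is installed and a GPU is available; the [num_tokens, hidden_size] shapes below are illustrative assumptions, not values taken from the patch.

    import torch
    from vllm import _custom_ops as ops

    # Illustrative shapes only; any 2-D activation tensor with a matching
    # 1-D weight vector works for rms_norm.
    num_tokens, hidden_size = 4, 4096
    x = torch.randn(num_tokens, hidden_size, dtype=torch.float16, device="cuda")
    weight = torch.ones(hidden_size, dtype=torch.float16, device="cuda")
    out = torch.empty_like(x)

    # Same signature as before the reformat: rms_norm(out, input, weight, epsilon)
    ops.rms_norm(out, x, weight, 1e-6)

If this runs without error and `out` is populated, the reformatted wrappers dispatch to the underlying vllm_ops kernels exactly as before.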