From e58294ddf23107d93987c00611d63a20e3cfe771 Mon Sep 17 00:00:00 2001
From: JGSweets
Date: Fri, 5 Jul 2024 12:41:01 -0500
Subject: [PATCH] [Bugfix] Add verbose error if scipy is missing for
 blocksparse attention (#5695)

---
 .../ops/blocksparse_attention/utils.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/vllm/attention/ops/blocksparse_attention/utils.py b/vllm/attention/ops/blocksparse_attention/utils.py
index 0d90dd971e156..b1808970d7939 100644
--- a/vllm/attention/ops/blocksparse_attention/utils.py
+++ b/vllm/attention/ops/blocksparse_attention/utils.py
@@ -6,7 +6,14 @@
 
 import torch
 import triton
-from scipy import sparse
+
+try:
+    from scipy import sparse
+except ImportError as err:
+    raise ImportError("Please install scipy via "
+                      "`pip install scipy` to use "
+                      "BlockSparseAttention in "
+                      "models such as Phi-3.") from err
 
 
 def dense_to_crow_col(x: torch.Tensor):
@@ -77,11 +84,11 @@ def _get_sparse_attn_mask_homo_head(
 ):
     """
     :return: a tuple of 3:
-        - tuple of crow_indices, col_indices representation 
+        - tuple of crow_indices, col_indices representation
           of CSR format.
         - block dense mask
-        - all token dense mask (be aware that it can be 
-          OOM if it is too big) if `return_dense==True`, 
+        - all token dense mask (be aware that it can be
+          OOM if it is too big) if `return_dense==True`,
           otherwise, None
     """
     with torch.no_grad():
@@ -148,10 +155,10 @@ def get_sparse_attn_mask(
     :param dense_mask_type: "binary" (0 for skip token, 1 for others)
         or "bias" (-inf for skip token, 0 or others)
     :return: a tuple of 3:
-        - tuple of crow_indices, col_indices representation 
+        - tuple of crow_indices, col_indices representation
          of CSR format.
         - block dense mask
-        - all token dense mask (be aware that it can be OOM if it 
+        - all token dense mask (be aware that it can be OOM if it
           is too big) if `return_dense==True`, otherwise, None
     """
     assert dense_mask_type in ("binary", "bias")
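
Note (not part of the patch): below is a minimal, self-contained sketch of the pattern this fix introduces. The guarded import runs once at module load, so a missing scipy surfaces as an actionable ImportError naming the package and the affected feature, rather than a bare ModuleNotFoundError from deep inside the attention path. The helper `crow_col_from_dense` is hypothetical, written to mirror what the patched module's `dense_to_crow_col` produces; `scipy.sparse.csr_matrix` and `torch.from_numpy` are real APIs, everything else is illustrative.

    # Sketch only: mirrors the guarded import added by this patch.
    try:
        from scipy import sparse
    except ImportError as err:
        # Fail fast with a message that says what to install and why.
        raise ImportError(
            "Please install scipy via `pip install scipy` to use "
            "BlockSparseAttention in models such as Phi-3.") from err

    import torch


    def crow_col_from_dense(mask: torch.Tensor):
        """Hypothetical helper: convert a dense 0/1 block mask to the CSR
        (crow_indices, col_indices) pair that the patched module's
        dense_to_crow_col returns. This is why scipy is required."""
        csr = sparse.csr_matrix(mask.bool().cpu().numpy())
        return (torch.from_numpy(csr.indptr.astype("int64")),
                torch.from_numpy(csr.indices.astype("int64")))


    if __name__ == "__main__":
        # 2x2 block-causal mask: each block row attends to itself and
        # to earlier block rows.
        dense = torch.tensor([[1, 0], [1, 1]])
        crow, col = crow_col_from_dense(dense)
        print(crow.tolist(), col.tolist())  # [0, 1, 3] [0, 0, 1]

Import-time guarding is the design choice worth noting: the error is raised when the blocksparse module is first imported, before any model weights are loaded, instead of mid-inference where the traceback would not point at the missing dependency.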