diff --git a/lmdeploy/pytorch/kernels/pagedattention.py b/lmdeploy/pytorch/kernels/pagedattention.py
index 39fe92204b..99a92d769a 100644
--- a/lmdeploy/pytorch/kernels/pagedattention.py
+++ b/lmdeploy/pytorch/kernels/pagedattention.py
@@ -229,21 +229,27 @@ def _reduce_split_kernel(
     tl.store(Out + out_offs, acc)
 
 
-_NV_CAP = torch.cuda.get_device_capability()
-if _NV_CAP[0] >= 8:
+def _get_convert_pv(nv_capability):
+    """lazy load convert_pv."""
+    if nv_capability[0] >= 8:
+
+        @triton.jit
+        def convert_pv(p, v):
+            """convert pv."""
+            p = p.to(v.dtype)
+            return p, v
+    else:
 
-    @triton.jit
-    def _convert_pv(p, v):
-        """convert pv."""
-        p = p.to(v.dtype)
-        return p, v
-else:
+        @triton.jit
+        def convert_pv(p, v):
+            """convert pv."""
+            v = v.to(p.dtype)
+            return p, v
 
-    @triton.jit
-    def _convert_pv(p, v):
-        """convert pv."""
-        v = v.to(p.dtype)
-        return p, v
+    return convert_pv
+
+
+_convert_pv = None
 
 
 @triton.jit
@@ -408,6 +414,10 @@ def paged_attention_fwd(
         max_seqlen (int): The max input length.
         BLOCK (int): The kernel block size.
     """
+    global _convert_pv
+    if _convert_pv is None:
+        nv_cap = torch.cuda.get_device_capability()
+        _convert_pv = _get_convert_pv(nv_cap)
 
     def _kernel_meta():
         """kernel meta."""
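
Note on the pattern: before this patch, torch.cuda.get_device_capability() ran at module import time; after it, the capability check is deferred to the first call of paged_attention_fwd, presumably so that importing the module no longer requires an initialized CUDA context. Below is a minimal standalone sketch of the same lazy-initialization pattern, not part of the patch itself; the triton.jit decorators and kernel details are omitted so the sketch runs without Triton, and only the names from the diff are reused:

    import torch

    # Module-level cache: stays None until the first forward call.
    _convert_pv = None


    def _get_convert_pv(nv_capability):
        """Pick the cast direction based on the CUDA compute capability."""
        if nv_capability[0] >= 8:
            def convert_pv(p, v):
                # sm80+: cast the probabilities down to the value dtype.
                return p.to(v.dtype), v
        else:
            def convert_pv(p, v):
                # Older GPUs: cast the values up to the probability dtype.
                return p, v.to(p.dtype)
        return convert_pv


    def paged_attention_fwd():
        global _convert_pv
        if _convert_pv is None:
            # Only the first call touches the CUDA runtime; plain import
            # of this module never does.
            _convert_pv = _get_convert_pv(torch.cuda.get_device_capability())
        ...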