Skip to content

Commit

Permalink
lazy load convert_pv jit function (#1253)
Browse files: browse the repository at this point in the history
  • Loading branch information
grimoire authored Mar 11, 2024
1 parent 1fcd6d3 commit ae8eafb
Showing 1 changed file with 23 additions and 13 deletions.
36 changes: 23 additions & 13 deletions lmdeploy/pytorch/kernels/pagedattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,21 +229,27 @@ def _reduce_split_kernel(
tl.store(Out + out_offs, acc)


def _get_convert_pv(nv_capability):
    """Build the ``convert_pv`` Triton helper for a given compute capability.

    The helper is defined inside this function so that the variant is chosen
    only when a capability is supplied, instead of querying the device while
    the module is being imported.

    Args:
        nv_capability: CUDA compute capability as a ``(major, minor)``
            sequence, e.g. the value returned by
            ``torch.cuda.get_device_capability()``. Only the major version
            (index 0) is inspected.

    Returns:
        A ``triton.jit`` function ``convert_pv(p, v)`` returning ``(p, v)``
        after casting one operand so that both share a dtype.
    """
    if nv_capability[0] >= 8:
        # Major version >= 8: keep `v` as is and cast `p` to v's dtype.

        @triton.jit
        def convert_pv(p, v):
            """convert pv."""
            p = p.to(v.dtype)
            return p, v
    else:
        # Older devices: keep `p` as is and cast `v` to p's dtype.
        # NOTE(review): presumably v's dtype is not usable in the dot product
        # on these devices -- confirm.

        @triton.jit
        def convert_pv(p, v):
            """convert pv."""
            v = v.to(p.dtype)
            return p, v

    return convert_pv


# Lazily-initialised cache for the helper built by `_get_convert_pv`;
# stays None until a caller populates it.
_convert_pv = None


@triton.jit
Expand Down Expand Up @@ -408,6 +414,10 @@ def paged_attention_fwd(
max_seqlen (int): The max input length.
BLOCK (int): The kernel block size.
"""
global _convert_pv
if _convert_pv is None:
nv_cap = torch.cuda.get_device_capability()
_convert_pv = _get_convert_pv(nv_cap)

def _kernel_meta():
"""kernel meta."""
Expand Down

0 comments on commit ae8eafb

Please sign in to comment.