[PyTorch] Improve get_qkv_layout (#1214)
* improve get_attention_backend logic

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* polish logic and wording

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove redundant comment

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
cyanguwa and pre-commit-ci[bot] authored Oct 9, 2024
1 parent 2d87552 commit 5b6546c
Showing 1 changed file with 66 additions and 35 deletions.
transformer_engine/pytorch/attention.py (101 changes: 66 additions & 35 deletions)
@@ -4799,74 +4799,105 @@ def get_qkv_layout(
            `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
            `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
            `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
+    q: torch.Tensor
+        Query tensor. It may be different from input `q` as we try to fit tensors to
+        a supported layout.
+    k: torch.Tensor
+        Key tensor. It may be different from input `k` as we try to fit tensors to
+        a supported layout.
+    v: torch.Tensor
+        Value tensor. It may be different from input `v` as we try to fit tensors to
+        a supported layout.
     """
 
     check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
     assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"
 
     def run_iteratively(q, k, v):
+        # check data pointers
         data_ptr = q.untyped_storage().data_ptr()
         check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
+        check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k])
         data_ptr = k.untyped_storage().data_ptr()
         check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])
 
+        # check tensor shapes
+        shape = q.shape
+        check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
+        shape = k.shape
+        check_shapes_kv = shape[:-1] == v.shape[:-1]
+
+        # check tensor strides
         stride = q.stride()
         check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
         check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple(
             sv / v.shape[-1] for sv in v.stride()[:-1]
         )
 
-        shape = q.shape
-        check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
-        shape = k.shape
-        check_shapes_kv = shape[:-1] == v.shape[:-1]
+        # check tensor offsets for h3d and 3hd layouts
+        prod_h_d = q.shape[-1] * q.shape[-2]
+        check_3hd_offsets = all(x.storage_offset() == i * prod_h_d for i, x in enumerate([q, k, v]))
+        check_h3d_offsets = all(
+            x.storage_offset() == i * q.shape[-1] for i, x in enumerate([q, k, v])
+        )
 
-        last_dim_size = q.shape[-1]
-        check_last_dim_offsets_qkv = all(
-            i * last_dim_size == x.storage_offset() for i, x in enumerate([q, k, v])
+        # check tensor offsets for hd_h2d and hd_2hd layouts
+        prod_all_dims = [np.prod(x.shape) for x in [q, k]]
+        offset = prod_all_dims[0] if check_ptrs_qkv else 0
+        prod_h_d = k.shape[-1] * k.shape[-2]
+        check_2hd_offsets = all(
+            x.storage_offset() == (offset + i * prod_h_d) for i, x in enumerate([k, v])
         )
-        last_dim_size = k.shape[-1]
-        check_last_dim_offsets_kv = all(
-            i * last_dim_size == x.storage_offset() for i, x in enumerate([k, v])
+        check_h2d_offsets = all(
+            x.storage_offset() == (offset + i * k.shape[-1]) for i, x in enumerate([k, v])
        )
 
-        last_two_dims_size = q.shape[-1] * q.shape[-2]
-        check_last_two_dims_offsets_qkv = all(
-            i * last_two_dims_size == x.storage_offset() for i, x in enumerate([q, k, v])
+        # check tensor offsets for hd_hd_hd layouts
+        check_hd_offsets_qkv = (
+            all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k, v]))
+            if check_ptrs_qkv
+            else all(x.storage_offset() == 0 for i, x in enumerate([q, k, v]))
         )
+        check_hd_offsets_qk = (
+            all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k]))
+            if not check_ptrs_qkv and check_ptrs_qk
+            else all(x.storage_offset() == 0 for i, x in enumerate([q, k]))
+        )
-        last_two_dims_size = k.shape[-1] * k.shape[-2]
-        check_last_two_dims_offsets_kv = all(
-            i * last_two_dims_size == x.storage_offset() for i, x in enumerate([k, v])
+        check_hd_offsets_kv = (
+            all(x.storage_offset() == sum(prod_all_dims[1 : i + 1]) for i, x in enumerate([k, v]))
+            if not check_ptrs_qkv and check_ptrs_kv
+            else all(x.storage_offset() == 0 for i, x in enumerate([k, v]))
         )
 
-        if (
-            check_ptrs_qkv
-            and check_strides_qkv
-            and check_shapes_qkv
-            and check_last_two_dims_offsets_qkv
-            and not check_last_dim_offsets_qkv
-        ):
+        if check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_3hd_offsets:
             # sb3hd, bs3hd, t3hd
+            # one chunk of memory, qkv, with q, k, v interleaved at dim=-3 in qkv
             qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:]
-        elif (
-            check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_last_dim_offsets_qkv
-        ):
+        elif check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_h3d_offsets:
             # sbh3d, bsh3d, th3d
+            # one chunk of memory, qkv, with q, k, v interleaved at dim=-2 in qkv
             qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:]
-        elif (
-            check_ptrs_kv
-            and check_strides_kv
-            and check_shapes_kv
-            and check_last_two_dims_offsets_kv
-            and not check_last_dim_offsets_kv
-        ):
+        elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_2hd_offsets:
             # sbhd_sb2hd, bshd_bs2hd, thd_t2hd
+            # two chunks of memory, q and kv, with k, v interleaved at dim=-3 in kv
+            # q and kv may be disjoint or consecutive in memory, and when consecutive, they may
+            # have the same data pointer, i.e. check_ptrs_qkv=True
             qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
-        elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_last_dim_offsets_kv:
+        elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_h2d_offsets:
             # sbhd_sbh2d, bshd_bsh2d, thd_th2d
+            # two chunks of memory, q and kv, with k, v interleaved at dim=-2 in kv
+            # q and kv may be disjoint or consecutive in memory, and when consecutive, they may
+            # have the same data pointer, i.e. check_ptrs_qkv=True
            qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:]
-        elif check_strides_kv and check_shapes_kv:
+        elif (
+            check_strides_kv
+            and check_shapes_kv
+            and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk)
+        ):
             # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
+            # three chunks of memory, q, k and v, which may be disjoint or consecutive, and
+            # when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or
+            # check_ptrs_qk=True or check_ptrs_kv=True
             qkv_layout = "_".join(list([qkv_format]) * 3)
         else:
             qkv_layout = "not_supported"
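
As an illustration of the memory patterns the new offset checks encode, here is a minimal standalone sketch. It is not part of the commit; the toy sizes and the names q2, kv, k2, v2 are invented for demonstration, and it assumes qkv_format="sbhd".

    import torch

    s, b, h, d = 2, 2, 3, 4  # toy sequence, batch, head, and head-dim sizes

    # One chunk of memory with q, k, v interleaved at dim=-3: an "sb3hd" layout.
    qkv = torch.randn(s, b, 3, h, d)
    q, k, v = [x.squeeze(2) for x in qkv.split(1, dim=2)]

    # The three views share one storage, have identical shapes and strides, and
    # sit at storage offsets i * h * d -- exactly what check_3hd_offsets tests.
    assert all(
        x.untyped_storage().data_ptr() == qkv.untyped_storage().data_ptr() for x in (q, k, v)
    )
    assert [x.storage_offset() for x in (q, k, v)] == [0, h * d, 2 * h * d]

    # Two chunks of memory: a standalone q, plus k and v interleaved at dim=-3 in
    # one kv tensor: an "sbhd_sb2hd" layout. q2 and kv are disjoint allocations, so
    # check_ptrs_qkv is False and check_2hd_offsets expects offsets starting at 0.
    q2 = torch.randn(s, b, h, d)
    kv = torch.randn(s, b, 2, h, d)
    k2, v2 = [x.squeeze(2) for x in kv.split(1, dim=2)]
    assert [x.storage_offset() for x in (k2, v2)] == [0, h * d]

Given these tensors, the rewritten run_iteratively would classify the first group as sb3hd and the second as sbhd_sb2hd.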
