diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 17307574a9..6bc4270f02 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -20,5 +20,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py +pytest -v -s $TE_PATH/tests/pytorch/test_generation.py pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 81c4973756..e903ecb6e5 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -31,30 +31,118 @@ def apply_rotary_pos_emb_thd( t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor, - cp_size: int = 1, + start_positions: torch.Tensor, + cp_size: int = 1, cp_rank: int = 0, ) -> torch.Tensor: """A baseline implementation of applying RoPE for `thd` format. - + Args: t (Tensor): Input tensor T is of shape [t, h, d] cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and dtype torch.int32. freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] + start_positions (Tensor): Tensor of shape [b] determining the beginning offsets + of frequeuncies applied to sequences. Returns: Tensor: Shape [t, h, d]. The input tensor after applying RoPE. """ cu_seqlens = cu_seqlens // cp_size seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - return torch.cat( - [ - apply_rotary_pos_emb( - x.unsqueeze(1), _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs) - ) - for x in torch.split(t, seqlens) - ] - ).squeeze(1) + if start_positions is None: + return torch.cat( + [ + apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)]) + for x in torch.split(t, seqlens) + ] + ).squeeze(1) + else: + return torch.cat( + [ + apply_rotary_pos_emb( + x.unsqueeze(1), _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs) + ) + for x in torch.split(t, seqlens) + ] + ).squeeze(1) + + +def apply_rotary_pos_emb_with_start_positions( + t: torch.Tensor, + freqs: torch.Tensor, + tensor_format: str = "sbhd", + start_positions: Union[torch.Tensor, None] = None, +) -> torch.Tensor: + """ + Apply rotary positional embedding tensor to the input tensor. + This is non-fused version which supports start_positions parameters. + Non-fused implementation with start_positions is slow, thus it is not included in the + Transformer Engine directly. + + Parameters + ---------- + t: torch.Tensor + Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which + rotary positional embedding will be applied. + freqs: torch.Tensor + Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float', + with `s2 >= s` and `d2 <= d`. + tensor_format: {'sbhd', 'bshd'}, default = 'sbhd' + start_positions: torch.Tensor, default = None. + We may not want begin all the sequences from the 0 embedding. + This tensor argument allows that. 
+ """ + + def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """ + change sign so the last dimension becomes [-odd, +even] + """ + x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2))) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + if start_positions is None: + return apply_rotary_pos_emb(t, freqs, tensor_format=tensor_format) + + max_seq_len = freqs.shape[0] + cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] + + # Only apply the rotary embeddings up to the sequence length of the running + # input. + assert ( + cur_seq_len <= max_seq_len + ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" + + if tensor_format == "bshd": + t = t.transpose(0, 1) + # cos/sin first then dtype conversion for better precision + cos_ = torch.cos(freqs).to(t.dtype) + sin_ = torch.sin(freqs).to(t.dtype) + + rot_dim = freqs.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + + # shifted_sin, shifted_cos will have the same shape as t. They will contain + # scaling factors shifted for each sequence by the corresponding start_positions offset. + + shifted_sin = sin_[:cur_seq_len].expand(t.shape).clone() + shifted_cos = cos_[:cur_seq_len].expand(t.shape).clone() + + for b in range(start_positions.shape[0]): + assert max_seq_len >= start_positions[b] + shifted_freq = slice(start_positions[b], (start_positions[b] + cur_seq_len)) + shifted_sin[:, b, :] = sin_[shifted_freq, 0, ...] + shifted_cos[:, b, :] = cos_[shifted_freq, 0, ...] + + t = (t * shifted_cos) + (_rotate_half(t) * shifted_sin) + out = torch.cat((t, t_pass), dim=-1) + + if tensor_format == "bshd": + out = out.transpose(0, 1).contiguous() + + return out # Gradient is a broadcasted scalar @@ -73,8 +161,9 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor: @pytest.mark.parametrize("hidden_size", [128, 256]) @pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) @pytest.mark.parametrize("margin", [0, 10]) +@pytest.mark.parametrize("start_positions", [True, False]) @pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)]) -@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"]) +@pytest.mark.parametrize("tensor_format", ["bshd", "sbhd"]) @pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) def test_fused_rope( dtype: torch.dtype, @@ -82,6 +171,7 @@ def test_fused_rope( hidden_size: int, rotary_percent: float, margin: int, + start_positions: bool, transpose: Union[Tuple, None], tensor_format: str, loss_func: Callable, @@ -99,14 +189,24 @@ def test_fused_rope( t = t.transpose(*transpose).contiguous().transpose(*transpose) t.requires_grad = True + if margin == 0 and start_positions == True: + # If sequence to encode has the same length as length of encoding + # there is no space left for starting with positions >0. 
+ pytest.skip("Skipping test with margin=0 and start_positions=True") + + start_positions = ( + torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device) + if start_positions + else None + ) + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) emb = rotary_pos_emb(seq_length) - # unfused # The fused kernel computes in float32 internally, so we force the unfused func to use float32 # for more accurate comparison output_unfused = apply_rotary_pos_emb( - t.float(), emb, tensor_format=tensor_format, fused=False + t.float(), emb, tensor_format=tensor_format, start_positions=start_positions, fused=False ).to(dtype) loss_unfused = loss_func(output_unfused) loss_unfused.backward() @@ -115,10 +215,7 @@ def test_fused_rope( # fused output_fused = apply_rotary_pos_emb( - t, - emb, - tensor_format=tensor_format, - fused=True, + t, emb, tensor_format=tensor_format, fused=True, start_positions=start_positions ) loss_fused = loss_func(output_fused) loss_fused.backward() @@ -135,6 +232,7 @@ def test_fused_rope( @pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) @pytest.mark.parametrize("transpose", [None, (1, 2)]) @pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) +@pytest.mark.parametrize("start_positions", [True, False]) @pytest.mark.parametrize("cp_size", [1, 2, 3]) def test_fused_rope_thd( dtype: torch.dtype, @@ -142,6 +240,7 @@ def test_fused_rope_thd( rotary_percent: float, transpose: Union[Tuple, None], loss_func: Callable, + start_positions: bool, cp_size: int, ) -> None: device = torch.device("cuda:0") @@ -170,6 +269,12 @@ def test_fused_rope_thd( t = t.transpose(*transpose).contiguous().transpose(*transpose) t.requires_grad = True + start_positions = ( + torch.randint(0, 20, (cu_seqlens.shape[-1],), dtype=torch.int32, device=device) + if start_positions + else None + ) + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) emb = rotary_pos_emb(cu_seqlens_padded[-1]) @@ -178,7 +283,7 @@ def test_fused_rope_thd( # The fused kernel computes in float32 internally, so we force the unfused func to use float32 # for more accurate comparison output_unfused = apply_rotary_pos_emb_thd( - t.float(), cu_seqlens_padded, emb, cp_size, cp_rank + t.float(), cu_seqlens_padded, emb, start_positions, cp_size, cp_rank ).to(dtype) loss_unfused = loss_func(output_unfused) loss_unfused.backward() @@ -189,8 +294,9 @@ def test_fused_rope_thd( output_fused = apply_rotary_pos_emb( t, emb, - fused=True, tensor_format="thd", + fused=True, + start_positions=start_positions, cu_seqlens=cu_seqlens_padded, cp_size=cp_size, cp_rank=cp_rank, diff --git a/tests/pytorch/test_generation.py b/tests/pytorch/test_generation.py new file mode 100644 index 0000000000..343dd4db1d --- /dev/null +++ b/tests/pytorch/test_generation.py @@ -0,0 +1,210 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch + +import transformer_engine.pytorch as te + + +class TestInferenceParams: + def test_setup_before_new_input_bshd(self): + inference_params = te.attention.InferenceParams(64, 128, qkv_format="bshd") + + inference_params.setup_before_new_input(length=16) + # Offset before first sequence is equal to 0. + assert inference_params.sequence_len_offset == 0 + + # Offset before second sequence is equal to 16. 
+ inference_params.setup_before_new_input(length=4) + assert inference_params.sequence_len_offset == 16 + + def test_setup_before_new_input_thd(self): + inference_params = te.attention.InferenceParams(4, 128, qkv_format="thd") + + inference_params.setup_before_new_input( + lengths_tensor=torch.Tensor([1, 0, 2, 4]).cuda(), max_input_length=20 + ) + + assert torch.equal( + inference_params.cached_sequence_lengths, torch.Tensor([0, 0, 0, 0]).cuda() + ) + assert torch.equal( + inference_params.input_sequence_lengths, torch.Tensor([1, 0, 2, 4]).cuda() + ) + assert inference_params.max_incoming_seq_len == 20 + + inference_params.setup_before_new_input( + lengths_tensor=torch.Tensor([2, 3, 5, 1]).cuda(), max_input_length=10 + ) + assert torch.equal( + inference_params.cached_sequence_lengths, torch.Tensor([1, 0, 2, 4]).cuda() + ) + assert torch.equal( + inference_params.input_sequence_lengths, torch.Tensor([2, 3, 5, 1]).cuda() + ) + assert inference_params.max_incoming_seq_len == 10 + + @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) + @pytest.mark.parametrize("batch_size", [64, 128, 256]) + @pytest.mark.parametrize("max_seq_len", [128, 256, 512]) + @pytest.mark.parametrize("max_input_len", [32, 128]) + def test_save_to_kv_cache_thd(self, batch_size, max_seq_len, max_input_len, dtype): + h, d = 16, 256 + + inference_params = te.attention.InferenceParams(batch_size, max_seq_len, qkv_format="thd") + inference_params.allocate_memory_for_kv_cache_if_empty(1, h, d, dtype) + + t = batch_size * max_input_len + key_layer = torch.randn((t, h, d)).cuda().to(dtype) + value_layer = torch.randn((t, h, d)).cuda().to(dtype) + + sequence_lengths = [1, 2] * (batch_size // 2) + + # We save the same sequences two time, which should result in sequences of lentgh 2 and 4 + # in the cache + inference_params.reset() + inference_params.setup_before_new_input( + lengths_tensor=torch.tensor(sequence_lengths).cuda(), max_input_length=max_input_len + ) + inference_params.save_to_kv_cache(1, key_layer, value_layer) + + inference_params.setup_before_new_input( + lengths_tensor=torch.tensor(sequence_lengths).cuda(), max_input_length=max_input_len + ) + inference_params.save_to_kv_cache(1, key_layer, value_layer) + + key_memory, value_memory = inference_params.key_value_memory_dict[1] + + # Chcek whether the sequences were copied properly. + + def check(memory, layer, b, idx1, idx2): + # Check if sequence idx in batch b in memory corresponds + # to the sequence idx2 in batch b in layer. 
+ assert torch.equal(memory[b * max_seq_len + idx1], layer[b * max_input_len + idx2, :]) + + # even indices + for b in range(0, batch_size, 2): + check(key_memory, key_layer, b, 0, 0) + check(key_memory, key_layer, b, 1, 0) + assert (key_memory[b * max_seq_len + 2 : ((b + 1) * max_seq_len)] == 0).all() + + check(value_memory, value_layer, b, 0, 0) + check(value_memory, value_layer, b, 1, 0) + assert (value_memory[b * max_seq_len + 2 : ((b + 1) * max_seq_len)] == 0).all() + + # odd indices + for b in range(1, batch_size, 2): + check(key_memory, key_layer, b, 0, 0) + check(key_memory, key_layer, b, 1, 1) + check(key_memory, key_layer, b, 2, 0) + check(key_memory, key_layer, b, 3, 1) + assert (key_memory[b * max_seq_len + 4 : ((b + 1) * max_seq_len)] == 0).all() + + check(value_memory, value_layer, b, 0, 0) + check(value_memory, value_layer, b, 1, 1) + check(value_memory, value_layer, b, 2, 0) + check(value_memory, value_layer, b, 3, 1) + assert (value_memory[b * max_seq_len + 4 : ((b + 1) * max_seq_len)] == 0).all() + + @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) + @pytest.mark.parametrize("batch_size", [64, 128, 256]) + @pytest.mark.parametrize("max_seq_len", [128, 256, 512]) + def test_save_to_kv_cache_bshd(self, batch_size, max_seq_len, dtype): + # This test checks if key_layer and value_layer are copied to cache. + # Cache size is equal to the size of one key/value layer. + h, d = 16, 256 + + inference_params = te.attention.InferenceParams(batch_size, max_seq_len, qkv_format="bshd") + + inference_params.allocate_memory_for_kv_cache_if_empty(1, h, d, dtype) + key_layer = torch.randn((max_seq_len, batch_size, h, d)).cuda().to(dtype) + value_layer = torch.randn((max_seq_len, batch_size, h, d)).cuda().to(dtype) + + inference_params.setup_before_new_input(length=0) + inference_params.save_to_kv_cache(1, key_layer, value_layer) + + key_memory, value_memory = inference_params.key_value_memory_dict[1] + + assert torch.equal(key_memory, key_layer) + assert torch.equal(value_memory, value_layer) + + @pytest.mark.parametrize("layer_number", [1, 100]) + @pytest.mark.parametrize("batch_size", [1, 128]) + @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) + def test_allocate_memory_for_kv_cache_if_empty(self, layer_number, batch_size, dtype): + nr_heads = 16 + head_dim = 256 + max_sequence_len = 128 + inference_params = te.attention.InferenceParams( + batch_size, max_sequence_len, qkv_format="bshd" + ) + + assert layer_number not in inference_params.key_value_memory_dict + + inference_params.allocate_memory_for_kv_cache_if_empty( + layer_number, nr_heads, head_dim, dtype + ) + + key_memory, value_memory = inference_params.key_value_memory_dict[layer_number] + + assert key_memory.shape == (max_sequence_len, batch_size, nr_heads, head_dim) + assert value_memory.shape == (max_sequence_len, batch_size, nr_heads, head_dim) + + # Should not allocate new buffers. + inference_params.allocate_memory_for_kv_cache_if_empty(layer_number, 100, 100, dtype) + + assert key_memory.shape == (max_sequence_len, batch_size, nr_heads, head_dim) + assert value_memory.shape == (max_sequence_len, batch_size, nr_heads, head_dim) + + def test_set_params_to_thd_attention(self): + # This test check whether parameteres needed to run thd attention + # are computed correcly. This parameters are passed to the fused_attn_fwd(..) + # to indicate which parts of the key/query/value layers are sequences and + # which of them are offsets. 
+ batch_size = 4 + channels = 1024 + max_sequence_len = 128 + max_input_len = 20 + inference_params = te.attention.InferenceParams( + batch_size, max_sequence_len, qkv_format="thd" + ) + + inference_params.setup_before_new_input( + lengths_tensor=torch.Tensor([1, 1, 1, 1]).cuda(), max_input_length=max_input_len + ) + inference_params.setup_before_new_input( + lengths_tensor=torch.Tensor([1, 0, 2, 4]).cuda(), max_input_length=max_input_len + ) + + buffers = [torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") for _ in range(6)] + max_q_len, max_kv_len, buffers = inference_params.set_params_to_thd_attention( + buffers, channels + ) + + cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = ( + buffers + ) + + assert max_q_len == max_input_len + assert max_kv_len == max_sequence_len + assert torch.equal(cu_seqlens_q, torch.tensor([0, 1, 1, 3, 7]).cuda()) + assert torch.equal(cu_seqlens_kv, torch.tensor([0, 2, 3, 6, 11]).cuda()) + + assert torch.equal( + seq_offsets_q, + torch.tensor([k * max_input_len * channels for k in range(batch_size + 1)]).cuda(), + ) + assert torch.equal( + seq_offsets_k, + torch.tensor([k * max_sequence_len * channels for k in range(batch_size + 1)]).cuda(), + ) + assert torch.equal( + seq_offsets_v, + torch.tensor([k * max_sequence_len * channels for k in range(batch_size + 1)]).cuda(), + ) + assert torch.equal( + seq_offsets_o, + torch.tensor([k * max_input_len * channels for k in range(batch_size + 1)]).cuda(), + ) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index c237dbaeb6..51d0d1a486 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -3,8 +3,9 @@ # See LICENSE for license information. import math +import functools import os -from typing import Dict, List, Optional +from typing import Dict, List, Tuple, Optional import pytest import copy import random @@ -13,6 +14,8 @@ import torch.nn as nn from torch.nn import Parameter +import transformer_engine.pytorch.cpp_extensions as ext + from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_model_init from transformer_engine.pytorch.utils import ( init_method_normal, @@ -45,6 +48,22 @@ sm_80plus = get_device_compute_capability() >= (8, 0) +@functools.cache +def _cudnn_version() -> Tuple[int, int, int]: + """Runtime cuDNN version (major, minor, patch)""" + encoded_version = ext.get_cudnn_version() + major_version_magnitude = 1000 if encoded_version < 90000 else 10000 + major, encoded_version = divmod(encoded_version, major_version_magnitude) + minor, patch = divmod(encoded_version, 100) + return (major, minor, patch) + + +def get_device_compute_capability() -> Tuple[int, int]: + """CUDA compute capability of current GPU""" + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + return (props.major, props.minor) + + seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) @@ -2034,6 +2053,139 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, assert_allclose(full_output, incremental_output, atol[dtype]) +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model_key", model_configs_inference.keys()) +@pytest.mark.parametrize("use_RoPE", all_boolean) +@pytest.mark.parametrize("module", module_inference) +@pytest.mark.skipif( + get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+." 
+) +@pytest.mark.skipif(_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.") +def test_kv_cache_accuracy_thd(dtype, bs, model_key, use_RoPE, module): + """ + In thd attention sequences can have various lengths, + different that 's' dimension of input to the Transformer Layer. + + The test contains of: + - one context phase when sequences with various lengths(!) are passed through the model, + - 2 phases when sequences with length 1 are passed through the model. + + The output is compared with the case when all this sequences are passed at one. + """ + if dtype == torch.float32: + pytest.skip("torch.float32 does not support thd") + + fused_attn_env = os.environ["NVTE_FUSED_ATTN"] + os.environ["NVTE_FUSED_ATTN"] = "1" # Only fused attention supports thd. + + if not fp8_available: + pytest.skip(reason_for_no_fp8) + + config = model_configs_inference[model_key] + + S = config.seq_len + B = bs + H = config.num_attention_heads + D = config.hidden_size + G = 2 # generation phase length + S_max = S + G + head_size = config.embed + + layer_number = 1 + rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda") + + # Tensors have shapes [b, s, h, d] and the seqlens are the tensor of shapes [b] + # which indicate the length of sequences - sequences starts from the begining. + # This function copies sequences from tensor into dst_tensor. + # dst_tensor should be big enough to fit this sequences. + def _concat_thd(dst_tensor, dst_seqlens, tensor, seqlens): + for b in range(B): + dst_tensor[b, dst_seqlens[b] : (dst_seqlens[b] + seqlens[b]), :] = tensor[ + b, : seqlens[b], : + ] + dst_seqlens.copy_(dst_seqlens + seqlens) + + if module == "TransformerLayer": + model = TransformerLayer( + hidden_size=D, + ffn_hidden_size=4 * D, + num_attention_heads=H, + attn_input_format="thd", + self_attn_mask_type="padding_causal", + layer_number=layer_number, + params_dtype=dtype, + device="cuda", + ).eval() + attn_name = "self_attn_mask_type" + else: + model = ( + MultiheadAttention( + hidden_size=D, + num_attention_heads=H, + qkv_format="thd", + layer_number=layer_number, + params_dtype=dtype, + attn_mask_type="padding_causal", + ) + .cuda() + .eval() + ) + attn_name = "attn_mask_type" + + inference_params = InferenceParams(B, S_max, qkv_format="thd") + + kwargs = { + "inference_params": inference_params, + "rotary_pos_emb": rotary_freqs if use_RoPE else None, + } + + total_sequence_lengths = torch.zeros((B,)).cuda().to(torch.int32) + total_tensor = torch.zeros((B, S_max, D)).cuda().to(dtype) + + # Sequences split into chunks. + + # context phase + sequence_lengths = torch.randint(1, S, (B,)).cuda().to(torch.int32) + chunk = torch.randn((B, S, D)).cuda().to(dtype) + inference_params.setup_before_new_input(max_input_length=S, lengths_tensor=sequence_lengths) + model( + chunk, inference_params=inference_params, rotary_pos_emb=rotary_freqs if use_RoPE else None + ) + _concat_thd(total_tensor, total_sequence_lengths, chunk, sequence_lengths) + + # generation phase + for _ in range(G): + sequence_lengths = torch.ones((B,)).cuda().to(torch.int32) + chunk = torch.randn((B, 1, D)).cuda().to(dtype) + inference_params.setup_before_new_input(max_input_length=1, lengths_tensor=sequence_lengths) + # we need to remove 'causal' from mask + # otherwise tokens we add will be considered as a first in the sequence, + # but they need to interact with all tokens from key-value cache. 
+ # after removing this line, tests should fail + kwargs[attn_name] = "padding" + output = model(chunk, **kwargs) + _concat_thd(total_tensor, total_sequence_lengths, chunk, sequence_lengths) + incremental_logits = output[:, -1, :] # last element of each seq. + + # Sequences passed in one, concatenated chunk. + + kwargs[attn_name] = "padding_causal" # add 'causal' back to the mask + inference_params.reset() + inference_params.setup_before_new_input( + max_input_length=S_max, lengths_tensor=total_sequence_lengths + ) + full_output = model(total_tensor, **kwargs) + full_logits = full_output[ + torch.arange(0, B), total_sequence_lengths - 1, : + ] # last element of each seq. + + # Final result should be close. + torch.testing.assert_close(full_logits, incremental_logits, atol=1e-2, rtol=1e-2) + + os.environ["NVTE_FUSED_ATTN"] = fused_attn_env + + @pytest.mark.parametrize( "shape", [ diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index 26f104d3ed..c083268414 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -87,36 +87,45 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq } template -__global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst, - const int h, const int d, const int d2, - const int stride_s, const int stride_b, - const int stride_h, const int stride_d, - const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d) { +__global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freqs, + const int *start_positions, scalar_t *dst, const int h, + const int d, const int d2, const int stride_s, + const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, + const int o_stride_b, const int o_stride_h, + const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; + int begin_offset = (start_positions == 0) ? 0 : start_positions[b_id]; int offset_block = s_id * stride_s + b_id * stride_b; int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; + + s_id = s_id + begin_offset; fused_rope_block_forward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); } template -__global__ void fused_rope_backward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst, - const int h, const int d, const int d2, - const int stride_s, const int stride_b, - const int stride_h, const int stride_d, - const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d) { +__global__ void fused_rope_backward_kernel(const scalar_t *src, const float *freqs, + const int *start_positions, scalar_t *dst, const int h, + const int d, const int d2, const int stride_s, + const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, + const int o_stride_b, const int o_stride_h, + const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; + int begin_offset = (start_positions == 0) ? 
0 : start_positions[b_id]; int offset_block = s_id * stride_s + b_id * stride_b; int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; + + s_id = s_id + begin_offset; fused_rope_block_backward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); } template __global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu_seqlens, - const float *freqs, scalar_t *dst, const int cp_size, + const float *freqs, const int *start_positions, + scalar_t *dst, const int cp_size, const int cp_rank, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, @@ -140,7 +149,8 @@ __global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; } } else { - s_id_for_freqs = s_id; + int begin_offset = (start_positions == 0) ? 0 : start_positions[b_id]; + s_id_for_freqs = s_id + begin_offset; } fused_rope_block_forward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); @@ -148,7 +158,8 @@ __global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu template __global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *cu_seqlens, - const float *freqs, scalar_t *dst, const int cp_size, + const float *freqs, const int *start_positions, + scalar_t *dst, const int cp_size, const int cp_rank, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, @@ -172,15 +183,17 @@ __global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *c cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; } } else { - s_id_for_freqs = s_id; + int begin_offset = (start_positions == 0) ? 
0 : start_positions[b_id]; + s_id_for_freqs = s_id + begin_offset; } fused_rope_block_backward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); } template -void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, scalar_t *output, - const int s, const int b, const int h, const int d, const int d2, +void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, + const int *start_positions, scalar_t *output, const int s, + const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b, const int o_stride_h, const int o_stride_d, cudaStream_t stream) { @@ -189,31 +202,32 @@ void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, scal dim3 threads(THREADS_PER_WARP, warps_per_block); fused_rope_forward_kernel<<>>( - input, freqs, output, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, - o_stride_b, o_stride_h, o_stride_d); + input, freqs, start_positions, output, h, d, d2, stride_s, stride_b, stride_h, stride_d, + o_stride_s, o_stride_b, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } template void fused_rope_backward_launcher(const scalar_t *output_grads, const float *freqs, - scalar_t *input_grads, const int s, const int b, const int h, - const int d, const int d2, const int stride_s, const int stride_b, - const int stride_h, const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, const int o_stride_d, - cudaStream_t stream) { + const int *start_positions, scalar_t *input_grads, const int s, + const int b, const int h, const int d, const int d2, + const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, cudaStream_t stream) { int warps_per_block = h < 16 ? 
4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); fused_rope_backward_kernel<<>>( - output_grads, freqs, input_grads, h, d, d2, stride_s, stride_b, stride_h, stride_d, - o_stride_s, o_stride_b, o_stride_h, o_stride_d); + output_grads, freqs, start_positions, input_grads, h, d, d2, stride_s, stride_b, stride_h, + stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } template void fused_rope_thd_forward_launcher(const scalar_t *input, const int *cu_seqlens, - const float *freqs, scalar_t *output, const int cp_size, + const float *freqs, const int *start_positions, + scalar_t *output, const int cp_size, const int cp_rank, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, @@ -224,14 +238,15 @@ void fused_rope_thd_forward_launcher(const scalar_t *input, const int *cu_seqlen dim3 threads(THREADS_PER_WARP, warps_per_block); fused_rope_thd_forward_kernel<<>>( - input, cu_seqlens, freqs, output, cp_size, cp_rank, h, d, d2, stride_t, stride_h, stride_d, + input, cu_seqlens, freqs, start_positions, output, cp_size, cp_rank, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } template void fused_rope_thd_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens, - const float *freqs, scalar_t *input_grads, const int cp_size, + const float *freqs, const int *start_positions, + scalar_t *input_grads, const int cp_size, const int cp_rank, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, @@ -242,41 +257,45 @@ void fused_rope_thd_backward_launcher(const scalar_t *output_grads, const int *c dim3 threads(THREADS_PER_WARP, warps_per_block); fused_rope_thd_backward_kernel<<>>( - output_grads, cu_seqlens, freqs, input_grads, cp_size, cp_rank, h, d, d2, stride_t, stride_h, + output_grads, cu_seqlens, freqs, start_positions, input_grads, cp_size, cp_rank, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } -void fused_rope_forward(const Tensor &input, const Tensor &freqs, Tensor *output, const int s, - const int b, const int h, const int d, const int d2, const int stride_s, - const int stride_b, const int stride_h, const int stride_d, - const int o_stride_s, const int o_stride_b, const int o_stride_h, - const int o_stride_d, cudaStream_t stream) { +void fused_rope_forward(const Tensor &input, const Tensor &freqs, const Tensor &start_positions, + Tensor *output, const int s, const int b, const int h, const int d, + const int d2, const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, scalar_t, fused_rope_forward_launcher(reinterpret_cast(input.data.dptr), reinterpret_cast(freqs.data.dptr), + reinterpret_cast(start_positions.data.dptr), reinterpret_cast(output->data.dptr), s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream);); } -void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, Tensor *input_grads, - const int s, const int b, const int h, const int d, const int d2, - const int stride_s, const int stride_b, const int stride_h, - const 
int stride_d, const int o_stride_s, const int o_stride_b, - const int o_stride_h, const int o_stride_d, cudaStream_t stream) { +void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, + const Tensor &start_positions, Tensor *input_grads, const int s, + const int b, const int h, const int d, const int d2, const int stride_s, + const int stride_b, const int stride_h, const int stride_d, + const int o_stride_s, const int o_stride_b, const int o_stride_h, + const int o_stride_d, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( output_grads.data.dtype, scalar_t, fused_rope_backward_launcher(reinterpret_cast(output_grads.data.dptr), reinterpret_cast(freqs.data.dptr), + reinterpret_cast(start_positions.data.dptr), reinterpret_cast(input_grads->data.dptr), s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream);); } void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs, - Tensor *output, const int cp_size, const int cp_rank, const int max_s, + const Tensor &start_positions, Tensor *output, const int cp_size, + const int cp_rank, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, const int o_stride_d, cudaStream_t stream) { @@ -285,13 +304,15 @@ void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, const fused_rope_thd_forward_launcher(reinterpret_cast(input.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(freqs.data.dptr), + reinterpret_cast(start_positions.data.dptr), reinterpret_cast(output->data.dptr), cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);); } void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlens, - const Tensor &freqs, Tensor *input_grads, const int cp_size, + const Tensor &freqs, const Tensor &start_positions, + Tensor *input_grads, const int cp_size, const int cp_rank, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, @@ -301,6 +322,7 @@ void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlen fused_rope_thd_backward_launcher(reinterpret_cast(output_grads.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(freqs.data.dptr), + reinterpret_cast(start_positions.data.dptr), reinterpret_cast(input_grads->data.dptr), cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);); @@ -308,35 +330,39 @@ void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlen } // end namespace transformer_engine -void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, - const int s, const int b, const int h, const int d, const int d2, +void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, + const NVTETensor start_positions, NVTETensor output, const int s, + const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b, const int o_stride_h, const int o_stride_d, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_forward); using namespace transformer_engine; fused_rope_forward(*reinterpret_cast(input), - 
*reinterpret_cast(freqs), reinterpret_cast(output), - s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, - o_stride_h, o_stride_d, stream); + *reinterpret_cast(freqs), + *reinterpret_cast(start_positions), + reinterpret_cast(output), s, b, h, d, d2, stride_s, stride_b, + stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream); } void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, - NVTETensor input_grads, const int s, const int b, const int h, - const int d, const int d2, const int stride_s, const int stride_b, - const int stride_h, const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, const int o_stride_d, - cudaStream_t stream) { + const NVTETensor start_positions, NVTETensor input_grads, const int s, + const int b, const int h, const int d, const int d2, + const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_backward); using namespace transformer_engine; fused_rope_backward(*reinterpret_cast(output_grads), *reinterpret_cast(freqs), + *reinterpret_cast(start_positions), reinterpret_cast(input_grads), s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream); } void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor output, const int cp_size, + const NVTETensor freqs, const NVTETensor start_positions, + NVTETensor output, const int cp_size, const int cp_rank, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, @@ -346,12 +372,14 @@ void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seq fused_rope_thd_forward(*reinterpret_cast(input), *reinterpret_cast(cu_seqlens), *reinterpret_cast(freqs), + *reinterpret_cast(start_positions), reinterpret_cast(output), cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); } void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, const int cp_size, + const NVTETensor freqs, const NVTETensor start_positions, + NVTETensor input_grads, const int cp_size, const int cp_rank, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, @@ -361,6 +389,7 @@ void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETenso fused_rope_thd_backward( *reinterpret_cast(output_grads), *reinterpret_cast(cu_seqlens), *reinterpret_cast(freqs), - reinterpret_cast(input_grads), cp_size, cp_rank, max_s, b, h, d, d2, stride_t, + *reinterpret_cast(start_positions), reinterpret_cast(input_grads), + cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); } diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index b7b9b93881..5356776e1f 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -17,6 +17,7 @@ extern "C" { 
* * \param[in] input Input tensor for fused rope. * \param[in] freqs The freqs tensor. + * \param[in] start_positions The beginning offsets. * \param[out] output Output tensor. * \param[in] s Length of the s dimension of input. * \param[in] b Length of the b dimension of input. @@ -33,8 +34,9 @@ extern "C" { * \param[in] o_stride_d Stride of the d dimension of output. * \param[in] stream CUDA stream used for the operation. */ -void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, - const int s, const int b, const int h, const int d, const int d2, +void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, + const NVTETensor start_positions, NVTETensor output, const int s, + const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b, const int o_stride_h, const int o_stride_d, cudaStream_t stream); @@ -43,6 +45,7 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVT * * \param[in] output_grads Incoming gradient tensor for backward. * \param[in] freqs The freqs tensor. + * \param[in] start_positions The tensor with positions of first tokens in sequences. * \param[out] input_grads Input gradient tensor to calculate. * \param[in] s Length of the s dimension of output_grads. * \param[in] b Length of the b dimension of output_grads. @@ -60,17 +63,18 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVT * \param[in] stream CUDA stream used for the operation. */ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, - NVTETensor input_grads, const int s, const int b, const int h, - const int d, const int d2, const int stride_s, const int stride_b, - const int stride_h, const int stride_d, const int o_stride_s, - const int o_stride_b, const int o_stride_h, const int o_stride_d, - cudaStream_t stream); + const NVTETensor start_positions, NVTETensor input_grads, const int s, + const int b, const int h, const int d, const int d2, + const int stride_s, const int stride_b, const int stride_h, + const int stride_d, const int o_stride_s, const int o_stride_b, + const int o_stride_h, const int o_stride_d, cudaStream_t stream); /*! \brief Apply rotary positional embedding to the input tensor in thd format. * * \param[in] input Input tensor for fused rope. * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. * \param[in] freqs The freqs tensor. + * \param[in] start_positions The tensor with positions of first tokens in sequences. * \param[out] output Output tensor. * \param[in] cp_size Context parallel world size. * \param[in] cp_rank Context parallel rank. @@ -88,7 +92,8 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor fr * \param[in] stream CUDA stream used for the operation. */ void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor output, const int cp_size, + const NVTETensor freqs, NVTETensor start_positions, + NVTETensor output, const int cp_size, const int cp_rank, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, @@ -99,6 +104,7 @@ void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seq * \param[in] output_grads Incoming gradient tensor for backward. 
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. * \param[in] freqs The freqs tensor. + * \param[in] start_positions The beginning offsets. * \param[out] input_grads Input gradient to calculate. * \param[in] cp_size Context parallel world size. * \param[in] cp_rank Context parallel rank. @@ -116,7 +122,8 @@ void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seq * \param[in] stream CUDA stream used for the operation. */ void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, const int cp_size, + const NVTETensor freqs, NVTETensor start_positions, + NVTETensor input_grads, const int cp_size, const int cp_rank, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6b153fd3c1..2eb6a40b54 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -984,18 +984,43 @@ class InferenceParams: # pylint: disable=too-few-public-methods Parameters ---------- - max_batch_size : int + max_batch_size: int maximum batch size during inference. - max_sequence_length : int - maximum sequence length during inference. + max_sequence_length: int + maximum sequence length during inference. + qkv_format: str + Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. + `s` stands for the sequence length dimension, + `b` batch size, `h` the number of attention heads, + `d` head size, and `t` the total number of sequences in a batch, i.e. + `t = sum(s_i) for i = 0...b-1`. """ - def __init__(self, max_batch_size, max_sequence_length): + def __init__(self, max_batch_size, max_sequence_length, qkv_format="bshd"): + assert qkv_format in ["bshd", "sbhd", "thd"] + self.max_sequence_length = max_sequence_length self.max_batch_size = max_batch_size - self.sequence_len_offset = 0 - self.batch_size_offset = 0 + + # self.key_value_memory_dict[layer number] = (key_cache, value_cache) + # if qkv_format in ["bshd", "sbhd"]: (key/value)_cache.shape = [b/s, s/b, h, d] + # # if qkv_format = "thd": (key/value)_cache.shape = [t, h, d] self.key_value_memory_dict = {} + self.qkv_format = qkv_format + + if qkv_format == "thd": + # In thd attention layout input sequences can have different lenghts. + # self.input_sequence_lengths stores tensor of shape [b] with lengths of input sequences + # and self.cached_sequence_lengths is the sum of all previous input lengths tensors - + # equivalently it contains total lengths of cached sequences. + self.cached_sequence_lengths = torch.zeros( + (max_batch_size,), device="cuda", dtype=torch.int32) + self.input_sequence_lengths = torch.zeros( + (max_batch_size,), device="cuda", dtype=torch.int32) + else: + self.sequence_len_offset = 0 + self.batch_size_offset = 0 + self.input_sequence_length = None def swap_key_value_dict(self, batch_indices): """ @@ -1023,6 +1048,224 @@ def swap_key_value_dict(self, batch_indices): ) + def setup_before_new_input(self, lengths_tensor=None, max_input_length=None, length=None): + """ + Updates parameters representing incoming sequence lengths and lengths + of sequences in the cache. Should be called before every forward pass in the inference. + + Parameters + ---------- + lengths_tensor: torch.Tensor + 1d tensor with sequence lengths in new input. 
+ Should be used only when self.qkv_format = "thd". + max_input_length: int + Should be used only when self.qkv_format = "thd". + If the incoming sequences tensor has shape [b * s, h, d], + this should be equal to s. + length: int + Length of the incoming sequences. + Should be used only when self.qkv_format in ["bshd", "sbhd"]. + """ + if self.qkv_format == "thd": + assert lengths_tensor is not None and max_input_length is not None, \ + "lengths_tensor and max_input_length should not be none for qkv_format = \"thd\"" + torch.add( + self.cached_sequence_lengths, + self.input_sequence_lengths, + out=self.cached_sequence_lengths) + self.input_sequence_lengths.copy_(lengths_tensor) + self.max_incoming_seq_len = max_input_length + + else: + assert length is not None, \ + "length should not be none for qkv_format in [\"bshd\", \"sbhd\"]" + if self.input_sequence_length is not None: + self.sequence_len_offset += self.input_sequence_length + self.input_sequence_length = length + + def reset(self): + """ + Resets the parameters to allow the use of this object in a new generation iteration. + This method does not reallocate buffers, + making it more efficient than creating a new InferenceParams object. + Moreover, reusing the same object with the same buffers is compatible + with the CUDA Graphs. + """ + if self.qkv_format == "thd": + self.cached_sequence_lengths.zero_() + self.input_sequence_lengths.zero_() + else: + self.input_sequence_length = None + self.sequence_len_offset = 0 + + def save_to_kv_cache(self, layer_number, key_layer, value_layer): + """ + Saves key_layer and value_layer in the cache. + + Parameters + ---------- + layer_number: input + layer number of the current `TransformerLayer` when multiple such modules are + concatenated to form a transformer block. + key_layer: torch.Tensor + Tensor - of the format corresponding to the self.qkv_format - + representing key_layer. + Notice: if self.qkv_format in ["bshd", "sbhd"] then both layers are in format sbhd + Notice: if self.qkv_format = "thd", we assume that offsets of the sequences + are of the form k * self.max_incoming_seq_len for k = 0, ..., batch_size-1. + value_layer: int + Tensor - of the format corresponding to the self.qkv_format - + representing value_layer. + Notice: if self.qkv_format in ["bshd", "sbhd"] both layers are in format sbhd + Notice: if self.qkv_format = "thd", we assume that offsets of the sequences + are of the form k * self.max_incoming_seq_len for k = 0, ..., batch_size-1. + """ + # Current kernels work only with contiguous tensors, it can be made faster in the future. + key_layer, value_layer = key_layer.contiguous(), value_layer.contiguous() + inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number] + if self.qkv_format == "thd": + channels = inference_key_memory.shape[1] * inference_key_memory.shape[2] # h * d + # This kernels copies kernels from input layers into cache, + # taking into account the thd format and sequence lengths. 
+ tex.attention_copy( + inference_key_memory, + self.cached_sequence_lengths, + self.input_sequence_lengths, + key_layer, + self.max_incoming_seq_len, + self.max_sequence_length, + self.max_batch_size, + channels) + + tex.attention_copy( + inference_value_memory, + self.cached_sequence_lengths, + self.input_sequence_lengths, + value_layer, + self.max_incoming_seq_len, + self.max_sequence_length, + self.max_batch_size, + channels) + key_layer, value_layer = inference_key_memory, inference_value_memory + else: + assert self.qkv_format in ["bshd", "sbhd"], \ + "Attention format not supported by the inference." + batch_start = self.batch_size_offset + batch_end = batch_start + key_layer.size(1) + assert batch_end <= inference_key_memory.size(1) + + sequence_start = self.sequence_len_offset + sequence_end = sequence_start + key_layer.size(0) + assert sequence_end <= inference_key_memory.size(0) + + # Copy keys and values into KV-cache + seq_offsets = slice(sequence_start, sequence_end) + batch_offsets = slice(batch_start, batch_end) + inference_key_memory[seq_offsets, batch_offsets, ...] = key_layer + inference_value_memory[seq_offsets, batch_offsets, ...] = value_layer + key_layer = inference_key_memory[:sequence_end, batch_offsets, ...] + value_layer = inference_value_memory[:sequence_end, batch_offsets, ...] + return key_layer, value_layer + + def allocate_memory_for_kv_cache_if_empty( + self, + layer_number, + num_gqa_groups_per_partition, + hidden_size_per_attention_head, + dtype): + """ + Allocates memory for kv_cache for given layer, if it hasn't been alocated before. + + Parameters + ---------- + layer_number: input + layer number of the current `TransformerLayer` when multiple such modules are + concatenated to form a transformer block. + num_gqa_groups_per_partition: torch.Tensor + This will be third dimension of cache tensor. + hidden_size_per_attention_head: int + This will be fourth dimension of cache tensor. + """ + + if layer_number in self.key_value_memory_dict: + return # Already allocated + + b, s = self.max_batch_size, self.max_sequence_length + + def _allocate_memory(dims): + return torch.zeros( + *dims, + num_gqa_groups_per_partition, + hidden_size_per_attention_head, + dtype=dtype, + device=torch.cuda.current_device(), + ) + + # def _allocate_memory( + # self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype + # ) -> torch.Tensor: + # return torch.empty( + # inference_max_sequence_len, + # batch_size, + # self.num_gqa_groups_per_partition, + # self.hidden_size_per_attention_head, + # dtype=dtype, + # device=torch.cuda.current_device(), + # ) + + if self.qkv_format == "thd": + inference_key_memory = _allocate_memory((b * s,)) + inference_value_memory = _allocate_memory((b * s,)) + else: + inference_key_memory = _allocate_memory((s, b)) + inference_value_memory = _allocate_memory((s, b)) + self.key_value_memory_dict[layer_number] = ( + inference_key_memory, + inference_value_memory, + ) + + def set_params_to_thd_attention(self, buffers): + """ + Fused attention with q/k/v of thd layout with offsets needs some parameters informing + about sequence lengths. This function computes them and + saves them into the provided buffers. + + Parameters + ---------- + buffers: List[torch.Tensor] + buffers of size [batch_size + 1] for the parameters: + cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, + seq_offsets_k, seq_offsets_v, seq_offsets_o + respectively. + channels: int + value of num_heads * hidden_dim_for_each_head. 
+ + Returns + ---------- + max_seqlen_q: int + Maximal value of query sequence length. + max_seqlen_kv: int + Maximal value of key/value sequence length. + buffers: torch.Tensor + Tensor with filled buffers. + """ + max_seqlen_q, max_seqlen_kv = self.max_incoming_seq_len, self.max_sequence_length + + cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_k = buffers + + torch.cumsum(self.input_sequence_lengths, dim=0, out=cu_seqlens_q[1:]) + torch.cumsum( + self.cached_sequence_lengths + self.input_sequence_lengths, + dim=0, out=cu_seqlens_kv[1:]) + # If layer has shape [b * s_layer, h, d] + # offsets are of the form [k * s_layer * h * d for k = 0, ..., batch_size] + seq_offsets_q.copy_( + torch.arange(0, self.max_batch_size + 1, device="cuda") * max_seqlen_q) + seq_offsets_k.copy_( + torch.arange(0, self.max_batch_size + 1, device="cuda") * max_seqlen_kv) + + return max_seqlen_q, max_seqlen_kv, buffers + @torch.no_grad() def get_swa_mask( window_size: Tuple[int, int], @@ -4465,22 +4708,32 @@ def forward( t: torch.Tensor, freqs: torch.Tensor, tensor_format: str = "sbhd", + start_positions: Union[torch.Tensor, None] = None, cu_seqlens: Union[torch.Tensor, None] = None, cp_size: int = 1, cp_rank: int = 0, ) -> torch.Tensor: # pylint: disable=missing-function-docstring + + if start_positions is None: + # Each sequence will start from positional encoding corresponding to 0. + # Otherwise sequence i will start from positional encoding + # corresponding to start_positions[i]. + start_positions = torch.Tensor() + if freqs.dtype != torch.float32: freqs = freqs.float() if tensor_format == "sbhd": - output = tex.fused_rope_forward(t, freqs, False) + output = tex.fused_rope_forward(t, freqs, start_positions, False) elif tensor_format == "bshd": - output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1) + output = tex.fused_rope_forward( + t.transpose(0, 1), freqs, start_positions, True + ).transpose(0, 1) elif tensor_format == "thd": - output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank) + output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, start_positions, cp_size, cp_rank) else: raise ValueError(f"Unsupported tensor_format: {tensor_format}.") - ctx.save_for_backward(freqs, cu_seqlens) + ctx.save_for_backward(freqs, cu_seqlens, start_positions) ctx.tensor_format = tensor_format ctx.cp_size = cp_size ctx.cp_rank = cp_rank @@ -4488,18 +4741,19 @@ def forward( return output @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - # pylint: disable=missing-function-docstring - freqs, cu_seqlens = ctx.saved_tensors + def backward( + ctx, grad_output: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + freqs, cu_seqlens, start_positions = ctx.saved_tensors if ctx.tensor_format == "sbhd": - grad_input = tex.fused_rope_backward(grad_output, freqs, False) + grad_input = tex.fused_rope_backward(grad_output, freqs, start_positions, False) elif ctx.tensor_format == "bshd": grad_input = tex.fused_rope_backward( - grad_output.transpose(0, 1), freqs, True + grad_output.transpose(0, 1), freqs, start_positions, True ).transpose(0, 1) elif ctx.tensor_format == "thd": grad_input = tex.fused_rope_thd_backward( - grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank + grad_output, cu_seqlens, freqs, start_positions, ctx.cp_size, ctx.cp_rank ) else: raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.") @@ -4521,6 +4775,7 @@ def apply_rotary_pos_emb( freqs: torch.Tensor, tensor_format: str = 
"sbhd", fused: bool = False, + start_positions: Union[torch.Tensor, None] = None, cu_seqlens: Union[torch.Tensor, None] = None, cp_size: int = 1, cp_rank: int = 0, @@ -4544,24 +4799,41 @@ def apply_rotary_pos_emb( cu_seqlens: torch.Tensor, default = None. Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and dtype torch.int32. Only valid when `tensor_format` is 'thd'. - Should be `cu_seqlens_padded` when cp_size > 1. cp_size: int, default = 1. Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True. cp_rank: int, default = 0. Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True. + start_positions: torch.Tensor, default = None. + Token i from sequence s have position encoding corresponding to + position start_positions[i]. If start_positions=None, then this token has position i. + Should be `cu_seqlens_padded` when cp_size > 1. """ + assert not (start_positions is not None and not fused), \ + """start_positions != None and fused=False is not supported""" + if fused: assert ( tensor_format != "thd" or cu_seqlens is not None ), "cu_seqlens must not be None when tensor_format is 'thd'." - return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank) + # Fused RoPE expects the RoPE embedding tensor to be of shape "s 1 1 d" + if freqs.shape[1] == 1: + return FusedRoPEFunc.apply(t, freqs, tensor_format, start_positions, cu_seqlens, cp_size, cp_rank) + assert tensor_format in ("sbhd", "bshd"), ( "Only formats `sbhd` or `bshd` are supported for input tensor `t` " f"when fused is False, got {tensor_format}." ) - max_seq_len = freqs.shape[0] + # RoPE embeddings provided are of the form `s 1 1 d`. This is also the + # default tensor shape in `RotaryPositionEmbedding` above. + if freqs.shape[1] == 1: + max_seq_len = freqs.shape[0] + # RoPE embeddings are of the form `b s 1 d` and for now correspond to + # `arbitrary` attention mask. + else: + max_seq_len = freqs.shape[1] + cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] # Only apply the rotary embeddings up to the sequence length of the running @@ -4569,9 +4841,23 @@ def apply_rotary_pos_emb( assert ( cur_seq_len <= max_seq_len ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" - freqs = freqs[:cur_seq_len] - if tensor_format == "bshd": - freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim] + + # Slice the RoPE embeddings in case they aren't already and transpose + # if `bshd` format is being used. + if freqs.shape[1] == 1: + freqs = freqs[:cur_seq_len, ...] if tensor_format == "sbhd" else freqs[:cur_seq_len, ...].transpose(0, 1) + else: + # This is the case when the `freqs` embedding has the shape `bs1d` which + # means that every sequence in the batch could have a different sequence + # length and has padding to ensure the overall sequence length dimension + # in both the embedding `freqs` and the target tensor `t` are the same. + assert ( + cur_seq_len == max_seq_len + ), f"Rope embeddings are of shape {freqs.shape} while target tensor is \ + of shape {t.shape}. Since each sequence could potentially have different \ + lengths (albeit padded), make sure the provided rope embeddings \ + sequence dimension matches the target tensor sequence dimension." 
+ # cos/sin first then dtype conversion for better precision cos_ = torch.cos(freqs).to(t.dtype) sin_ = torch.sin(freqs).to(t.dtype) @@ -7442,6 +7728,8 @@ def __init__( self.cp_group = cp_group self.cp_global_ranks = cp_global_ranks self.cp_stream = cp_stream + # self.channels = kv_channels * num_attention_heads + # self.hidden_size_per_attention_head = kv_channels self.cp_comm_type = cp_comm_type self.hidden_size_per_attention_head_k = ( @@ -7542,6 +7830,15 @@ def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unuse self.register_load_state_dict_post_hook(remove_extra_states_check) + self._allocator = StaticBufferAllocator() + + + def alloc(self, size, dtype, device): + """ + Allocated the buffer and works correctly with CUDA Graphs. + """ + return self._allocator(size, dtype, device) + def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): @@ -7832,6 +8129,7 @@ def forward( first microbatch (since it is the first gradient being produced) """ + batch_size = key_layer.shape[0] with self.prepare_forward( query_layer, is_first_microbatch, @@ -7909,37 +8207,41 @@ def forward( if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" + # @sudhakars: Distinguish the case when both context and gen + # phase have `thd_thd_thd` layout. # convert causal to causal_bottom_right in inference when KV-caching is in use # so users can run with the same attn_mask_type for training and inference - if attn_mask_type in ["causal", "padding_causal"]: - attn_mask_type = attn_mask_type + "_bottom_right" + # if attn_mask_type in ["causal", "padding_causal"]: + # attn_mask_type = attn_mask_type + "_bottom_right" + if qkv_format == "bshd": key_layer = key_layer.transpose(0, 1) value_layer = value_layer.transpose(0, 1) - ( - inference_key_memory, - inference_value_memory, - ) = inference_params.key_value_memory_dict[self.layer_number] + key_layer, value_layer = inference_params.save_to_kv_cache( + self.layer_number, key_layer, value_layer + ) - batch_start = inference_params.batch_size_offset - batch_end = batch_start + key_layer.size(1) - assert batch_end <= inference_key_memory.size(1) + if qkv_format == "thd": + # Allocation of buffers, it works correctly with CUDA Graphs. + NR_BUFFERS = 4 + buffers = [ + self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") + for _ in range(NR_BUFFERS) + ] - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + key_layer.size(0) - assert sequence_end <= inference_key_memory.size(0) + max_seqlen_q, max_seqlen_kv, buffers = \ + inference_params.set_params_to_thd_attention(buffers) + cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, \ + seq_offsets_k = buffers - # Copy keys and values into KV-cache - inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = ( - key_layer - ) - inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = ( - value_layer - ) - key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] - value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] 
+ # @sudhakars: remove this hardcoded value + cu_seqlens_q_padded, cu_seqlens_kv_padded = seq_offsets_q, seq_offsets_k + + # query_layer is reshaped to the format [t, h, d] + # and make contiguous - needed by the THD attention + query_layer = query_layer.view(-1, *query_layer.shape[2:]).contiguous() if qkv_format == "bshd": key_layer = key_layer.transpose(0, 1) @@ -8103,14 +8405,26 @@ def forward( False ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes" - pad_between_seqs = ( - cu_seqlens_q_padded is not None - and not torch.equal(cu_seqlens_q_padded, cu_seqlens_q) - ) or ( - cu_seqlens_kv_padded is not None - and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv) - ) + # @sudhakars: if using Flash Attention, need to check the `cu_seqlens_padded` + # and `cu_seqlens` + pad_between_seqs = True + # @sudhakars: this condition isn't compatible with CUDA Graphs capture. + # pad_between_seqs = ( + # cu_seqlens_q_padded is not None + # and not torch.equal(cu_seqlens_q_padded, cu_seqlens_q) + # ) or ( + # cu_seqlens_kv_padded is not None + # and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv) + # ) + + # Check whether this is needed (@sudhakars27) + if self.attention_type == "self": + if self.qkv_format == "bshd" and query_layer.shape[1] != value_layer.shape[1] or \ + self.qkv_format == "sbhd" and query_layer.shape[0] != value_layer.shape[0]: + # Flash attention does not self-support max_seqlen_q != max_seqlen_kv + use_flash_attention = False + attention_params = AttentionParams( qkv_type=type(query_layer), qkv_dtype=query_layer.dtype, @@ -8666,6 +8980,14 @@ def __init__( **common_gemm_kwargs, ) + self._allocator = StaticBufferAllocator() + + def alloc(self, size, dtype, device): + """ + Allocated the buffer and works correctly with CUDA Graphs. + """ + return self._allocator(size, dtype, device) + def _allocate_memory( self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype ) -> torch.Tensor: @@ -8867,21 +9189,12 @@ def forward( # ================================================= if inference_params and self.layer_number is not None: - assert ( - self.qkv_format != "thd" - ), "qkv_format == thd is not supported for an inference with KV-cache!" 
if self.layer_number not in inference_params.key_value_memory_dict: - inf_max_seq_len = inference_params.max_sequence_length - inf_max_batch_size = inference_params.max_batch_size - inference_key_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size, hidden_states.dtype - ) - inference_value_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size, hidden_states.dtype - ) - inference_params.key_value_memory_dict[self.layer_number] = ( - inference_key_memory, - inference_value_memory, + inference_params.allocate_memory_for_kv_cache_if_empty( + self.layer_number, + self.num_gqa_groups_per_partition, + self.hidden_size_per_attention_head, + hidden_states.dtype ) else: ( @@ -8959,7 +9272,8 @@ def forward( dim=split_dim, ) - if self.qkv_format == "thd": + # @sudhakars: fix this to `self.qkv_format == "thd"` later + if len(query_layer.shape) == 4: query_layer, key_layer, value_layer = ( x.reshape(x.size(0), -1, self.hidden_size_per_attention_head) for x in (query_layer, key_layer, value_layer) @@ -9060,39 +9374,74 @@ def forward( q_pos_emb, k_pos_emb = rotary_pos_emb - # adjust key and value for inference - if inference_params is not None: - if self.qkv_format == "sbhd": - sequence_length = key_layer.size(0) - elif self.qkv_format == "bshd": - sequence_length = key_layer.size(1) - else: - raise ValueError(f"QKV format {self.qkv_format} not supported for KV caching.") - - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + sequence_length - - q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] - k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] - - query_layer = apply_rotary_pos_emb( - query_layer, - q_pos_emb, - self.qkv_format, - fused=True, - cu_seqlens=cu_seqlens_q, - cp_size=self.cp_size, - cp_rank=self.cp_rank, - ) - key_layer = apply_rotary_pos_emb( - key_layer, - k_pos_emb, - self.qkv_format, - fused=True, - cu_seqlens=cu_seqlens_kv, - cp_size=self.cp_size, - cp_rank=self.cp_rank, - ) + if self.qkv_format == "thd" and inference_params is not None: + # For thd attention incoming tokens can be on different positions, + # so we need to copy different positional encoding freqency + # for every sequence in a batch. + # + # For example if sequence lengths in context phase are: 2 and 5 (batch size=2), + # in first generation phase key_layer have shape [2, 1, d]. + # key_layer[0, :] corresponds to the token with position 3 = 2 + 1, + # and key_layer [1, :] corresponds to the token with position 6 = 5 + 1. + + query_layer = apply_rotary_pos_emb( + query_layer, + q_pos_emb, + "bshd", + fused=True, + start_positions=inference_params.cached_sequence_lengths, + cu_seqlens=cu_seqlens_q, + cp_size=self.cp_size, + cp_rank=self.cp_rank, + ) + key_layer = apply_rotary_pos_emb( + key_layer, + k_pos_emb, + "bshd", + fused=True, + start_positions=inference_params.cached_sequence_lengths, + cu_seqlens=cu_seqlens_q, + cp_size=self.cp_size, + cp_rank=self.cp_rank, + ) + else: + # adjust key and value for inference + if inference_params is not None: + if self.qkv_format == "sbhd": + sequence_length = key_layer.size(0) + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + sequence_length + + q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] + k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] 
+
+                    elif self.qkv_format == "bshd":
+                        sequence_length = key_layer.size(1)
+                        sequence_start = inference_params.sequence_len_offset
+                        sequence_end = sequence_start + sequence_length
+
+                        q_pos_emb = q_pos_emb[:, sequence_start:sequence_end, ...]
+                        k_pos_emb = k_pos_emb[:, sequence_start:sequence_end, ...]
+
+                query_layer = apply_rotary_pos_emb(
+                    query_layer,
+                    q_pos_emb,
+                    self.qkv_format,
+                    fused=q_pos_emb.shape[1] <= 1,  # `bs1d` frequencies fall back to the unfused path
+                    cu_seqlens=cu_seqlens_q,
+                    cp_size=self.cp_size,
+                    cp_rank=self.cp_rank,
+                )
+                key_layer = apply_rotary_pos_emb(
+                    key_layer,
+                    k_pos_emb,
+                    self.qkv_format,
+                    fused=k_pos_emb.shape[1] <= 1,  # `bs1d` frequencies fall back to the unfused path
+                    cu_seqlens=cu_seqlens_kv,
+                    cp_size=self.cp_size,
+                    cp_rank=self.cp_rank,
+                )
+
        # ===========================
        # Core attention computation
@@ -9118,6 +9467,12 @@ def forward(
            inference_params=inference_params,
        )

+        if self.qkv_format == "thd":
+            # [b * sq, h] -> [b, sq, h]
+            context_layer = context_layer.view(
+                (inference_params.max_batch_size, -1, context_layer.shape[1])
+            ).contiguous()
+
        # ===================
        # Output. [sq, b, h]
        # ===================
@@ -9138,3 +9493,20 @@ def forward(
        if self.input_layernorm and self.return_layernorm_output:
            outputs += (layernorm_output,)
        return outputs if len(outputs) > 1 else outputs[0]
+
+
+class StaticBufferAllocator(torch.nn.Module):
+    """
+    This class is used together with te.make_graphed_callables().
+    CUDA Graphs require all tensors to be static. Nevertheless, the torch API
+    make_graphed_callables() takes care of the outputs of torch modules and
+    makes them static, so wrapping memory allocation in a torch.nn.Module
+    greatly simplifies the code.
+    """
+
+    # pylint: disable=no-self-use
+    def forward(self, size, dtype, device):
+        """
+        Return a buffer of the given size, dtype and device.
+ """ + return torch.zeros(size, dtype=dtype, device=device) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index b039bf2d1b..970287a975 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -408,16 +408,18 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio **************************************************************************************************/ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, + const at::Tensor &start_positions, const bool transpose_output_memory); at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, + const at::Tensor &start_positions, const bool transpose_output_memory); at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, - const at::Tensor &freqs, const int cp_size, const int cp_rank); + const at::Tensor &freqs, const at::Tensor &start_positions, const int cp_size, const int cp_rank); at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, - const at::Tensor &freqs, const int cp_size, const int cp_rank); + const at::Tensor &freqs, const at::Tensor &start_positions, const int cp_size, const int cp_rank); /*************************************************************************************************** * Miscellaneous @@ -427,6 +429,17 @@ size_t get_cublasLt_version(); size_t get_cudnn_version(); +bool userbuf_comm_available(); + +void placeholder(); + +/*************************************************************************************************** + * Generation + **************************************************************************************************/ + +void attention_copy(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len, + torch::Tensor B, int max_incoming_seq_len, int max_seq_len, int b, int s); + /*************************************************************************************************** * Support THD format for Context Parallel **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index c0cd2e9920..225b9f3fc4 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -7,6 +7,7 @@ #include "extensions.h" at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, + const at::Tensor &start_positions, const bool transpose_output_memory) { using namespace transformer_engine; TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); @@ -55,16 +56,19 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, auto input_cu = makeTransformerEngineTensor(input); auto freqs_cu = makeTransformerEngineTensor(freqs); + auto start_positions_cu = makeTransformerEngineTensor(start_positions); auto output_cu = makeTransformerEngineTensor(output); - nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s, b, h, d, d2, - stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, - o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); + nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), start_positions_cu.data(), + output_cu.data(), s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, + o_stride_s, o_stride_b, 
o_stride_h, o_stride_d, + at::cuda::getCurrentCUDAStream()); return output; } at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, + const at::Tensor &start_positions, const bool transpose_output_memory) { using namespace transformer_engine; TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); @@ -111,17 +115,20 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor auto output_grads_cu = makeTransformerEngineTensor(output_grads); auto freqs_cu = makeTransformerEngineTensor(freqs); + auto start_positions_cu = makeTransformerEngineTensor(start_positions); auto input_grads_cu = makeTransformerEngineTensor(input_grads); - nvte_fused_rope_backward(output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b, h, - d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, - o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); + nvte_fused_rope_backward(output_grads_cu.data(), freqs_cu.data(), start_positions_cu.data(), + input_grads_cu.data(), s, b, h, d, d2, stride_s, stride_b, stride_h, + stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, + at::cuda::getCurrentCUDAStream()); return input_grads; } at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, - const at::Tensor &freqs, const int cp_size, const int cp_rank) { + const at::Tensor &freqs, const at::Tensor &start_positions, + const int cp_size, const int cp_rank) { using namespace transformer_engine; TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); @@ -163,9 +170,10 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_ auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens); auto freqs_cu = makeTransformerEngineTensor(freqs); auto output_cu = makeTransformerEngineTensor(output); + auto start_positions_cu = makeTransformerEngineTensor(start_positions); nvte_fused_rope_thd_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - output_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, stride_t, + start_positions_cu.data(), output_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); @@ -173,7 +181,8 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_ } at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, - const at::Tensor &freqs, const int cp_size, const int cp_rank) { + const at::Tensor &freqs, const at::Tensor &start_positions, + const int cp_size, const int cp_rank) { using namespace transformer_engine; TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); @@ -213,9 +222,10 @@ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Ten auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens); auto freqs_cu = makeTransformerEngineTensor(freqs); auto input_grads_cu = makeTransformerEngineTensor(input_grads); + auto start_positions_cu = makeTransformerEngineTensor(start_positions); nvte_fused_rope_thd_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - input_grads_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, + start_positions_cu.data(), input_grads_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); diff --git 
a/transformer_engine/pytorch/csrc/extensions/generation.cu b/transformer_engine/pytorch/csrc/extensions/generation.cu
new file mode 100644
index 0000000000..5a162f1af6
--- /dev/null
+++ b/transformer_engine/pytorch/csrc/extensions/generation.cu
@@ -0,0 +1,55 @@
+/*************************************************************************
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include "extensions.h"
+
+// Kernel used to update the KV cache when the attention layout is "thd".
+template <typename scalar_t>
+__global__ void attention_copy_kernel(scalar_t* cache_tensor, int* seq_len, int* incoming_seq_len,
+                                      scalar_t* hidden_tensor, int max_incoming_seq_len,
+                                      int max_seq_len, int b, int s) {
+  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
+    int to_copy = s * incoming_seq_len[batch_idx];
+    int offset = seq_len[batch_idx];
+
+    scalar_t* begin_cache_copy = cache_tensor + max_seq_len * s * batch_idx + s * offset;
+    scalar_t* begin_hidden_copy = hidden_tensor + s * batch_idx * max_incoming_seq_len;
+
+    for (int i = threadIdx.x; i < to_copy; i += blockDim.x) {
+      *(begin_cache_copy + i) = *(begin_hidden_copy + i);
+    }
+  }
+}
+
+template <typename scalar_t>
+void attention_copy_launcher(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len,
+                             torch::Tensor B, int max_incoming_seq_len, int max_seq_len, int b,
+                             int s) {
+  attention_copy_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+      reinterpret_cast<scalar_t*>(A.data_ptr()), seq_len.data_ptr<int>(),
+      incoming_seq_len.data_ptr<int>(), reinterpret_cast<scalar_t*>(B.data_ptr()),
+      max_incoming_seq_len, max_seq_len, b, s);
+}
+
+void attention_copy(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len,
+                    torch::Tensor B, int max_incoming_seq_len, int max_seq_len, int b, int s) {
+  if (A.scalar_type() == at::ScalarType::Half) {
+    using dtype = at::Half;
+    attention_copy_launcher<dtype>(A, seq_len, incoming_seq_len, B, max_incoming_seq_len,
+                                   max_seq_len, b, s);
+  } else if (A.scalar_type() == at::ScalarType::BFloat16) {
+    using dtype = at::BFloat16;
+    attention_copy_launcher<dtype>(A, seq_len, incoming_seq_len, B, max_incoming_seq_len,
+                                   max_seq_len, b, s);
+  } else if (A.scalar_type() == at::ScalarType::Float) {
+    using dtype = float;
+    attention_copy_launcher<dtype>(A, seq_len, incoming_seq_len, B, max_incoming_seq_len,
+                                   max_seq_len, b, s);
+  } else {
+    NVTE_ERROR("Unsupported dtype of the KV-cache tensor\n");
+  }
+}
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 39679ed669..ba190a1154 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -174,6 +174,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::call_guard<py::gil_scoped_release>());
  m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams);

+  // Generation
+  m.def("attention_copy", &attention_copy, "attention_copy");
+
  // Support THD format for Context Parallel
  m.def("thd_read_half_tensor", &thd_read_half_tensor,
        "Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD "
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py
index ad5476450b..438e88ef9c 100644
--- a/transformer_engine/pytorch/transformer.py
+++ b/transformer_engine/pytorch/transformer.py
@@ -184,6 +184,10 @@ class TransformerLayer(torch.nn.Module):
            head size.
            Note that these formats are very closely related to the `qkv_format` in
            the `MultiHeadAttention` and `DotProductAttention` modules.
+            Note: an experimental version of the 'thd' attention layout is supported
+            when :attr:`inference_params` is passed to the forward function.
+
+

    Parallelism parameters
    ----------------------
@@ -280,6 +284,9 @@ def __init__(
    ) -> None:
        super().__init__()

+        if ub_tp_comm_overlap:
+            assert tex.userbuf_comm_available(), "Userbuffer communication backend not available."
+
        self.self_attn_mask_type = self_attn_mask_type
        self.window_size = check_set_window_size(self_attn_mask_type, window_size)
        self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
@@ -710,6 +717,7 @@ def forward(
                attention_output, attention_bias, hidden_states, self.drop_path
            )

+        # Cross attention.
        if self.layer_type == "decoder":
            inter_attention_outputs = self.inter_attention(
@@ -753,6 +761,7 @@ def forward(
        if self.output_layernorm:
            output = self.layernorm(output)

+        # output: [s, b, h]
        return output
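
During `thd` generation, the RoPE start positions are taken from `inference_params.cached_sequence_lengths`, so each newly generated token is rotated with the frequency that follows its sequence's last cached position. A tiny illustrative sketch of that bookkeeping (the lengths 2 and 5 mirror the example in the comment above; not part of the patch):

import torch

cached_sequence_lengths = torch.tensor([2, 5])  # per-sequence KV-cache fill level
tokens_this_step = 1                            # one new token per sequence

# Token j produced in this step for sequence b gets rotary position
# cached_sequence_lengths[b] + j, i.e. the 3rd and 6th positional encodings here.
positions = cached_sequence_lengths.unsqueeze(1) + torch.arange(tokens_this_step)
print(positions.tolist())  # [[2], [5]]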
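
`apply_rotary_pos_emb` now dispatches on the second dimension of `freqs`: the canonical `[s, 1, 1, d2]` layout produced by `RotaryPositionEmbedding` is eligible for the fused kernel, while a per-sequence `[b, s, 1, d2]` layout (used, per the comments above, with the `arbitrary` mask) falls back to the unfused branch. A shape-only sketch of that dispatch test, with made-up sizes:

import torch

s, b, d2 = 16, 2, 32
freqs_s11d = torch.randn(s, 1, 1, d2)  # canonical layout -> fused path eligible
freqs_bs1d = torch.randn(b, s, 1, d2)  # per-sequence layout -> unfused fallback

assert freqs_s11d.shape[1] == 1        # selects the fused / shared-frequency path
assert freqs_bs1d.shape[1] != 1        # selects the unfused per-sequence path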
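
`set_params_to_thd_attention` fills the cumulative-sequence-length and offset buffers consumed by the fused THD attention. A small worked example of the same arithmetic (simplified: plain `cumsum`/`cat` instead of the in-place, CUDA-Graph-friendly writes used in the patch; all lengths are made up):

import torch

cached_sequence_lengths = torch.tensor([2, 5])   # tokens already in the KV-cache
input_sequence_lengths = torch.tensor([1, 1])    # tokens arriving in this step
max_incoming_seq_len, max_sequence_length = 1, 8
zero = torch.zeros(1, dtype=torch.int64)

cu_seqlens_q = torch.cat([zero, torch.cumsum(input_sequence_lengths, dim=0)])
cu_seqlens_kv = torch.cat(
    [zero, torch.cumsum(cached_sequence_lengths + input_sequence_lengths, dim=0)])

# Layers are stored as [b * s_layer, h, d], so sequence k starts at row k * s_layer.
batch_size = cached_sequence_lengths.numel()
seq_offsets_q = torch.arange(batch_size + 1) * max_incoming_seq_len
seq_offsets_k = torch.arange(batch_size + 1) * max_sequence_length

print(cu_seqlens_q.tolist())    # [0, 1, 2]
print(cu_seqlens_kv.tolist())   # [0, 3, 9]
print(seq_offsets_q.tolist())   # [0, 1, 2]
print(seq_offsets_k.tolist())   # [0, 8, 16]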
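
For the `bshd`/`sbhd` layouts, `save_to_kv_cache` writes the incoming step into a `[max_seq_len, max_batch_size, h, d]` cache at the current sequence/batch offsets and returns the filled prefix. A compact pure-PyTorch rendering of that branch as a standalone function (argument names are hypothetical; handy as a test reference):

import torch

def save_to_sbhd_kv_cache(cache_k, cache_v, key_layer, value_layer,
                          sequence_len_offset, batch_size_offset=0):
    # key_layer/value_layer: [s_step, b, h, d]; caches: [max_s, max_b, h, d].
    seq = slice(sequence_len_offset, sequence_len_offset + key_layer.size(0))
    bat = slice(batch_size_offset, batch_size_offset + key_layer.size(1))
    cache_k[seq, bat, ...] = key_layer
    cache_v[seq, bat, ...] = value_layer
    # Return everything cached so far for these sequences.
    return cache_k[:seq.stop, bat, ...], cache_v[:seq.stop, bat, ...]

# Example: cache of 8 positions, batch of 2; write this step's tokens at position 3.
ck, cv = torch.zeros(8, 2, 1, 4), torch.zeros(8, 2, 1, 4)
k, v = torch.ones(1, 2, 1, 4), torch.ones(1, 2, 1, 4)
k_all, v_all = save_to_sbhd_kv_cache(ck, cv, k, v, sequence_len_offset=3)
print(k_all.shape)  # torch.Size([4, 2, 1, 4])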
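
The `attention_copy` kernel copies each sequence's newly produced key/value rows into that sequence's slot of the flat `thd` cache, starting at the row given by the cached length, where `s` is the number of elements per token row (num_heads * head_dim in the Python caller). A minimal pure-PyTorch sketch of the same indexing, useful as a reference when unit-testing the CUDA path (tensor names mirror the kernel arguments; the flat `[b * max_seq_len, h, d]` layout is assumed):

import torch

def attention_copy_reference(cache, seq_len, incoming_seq_len, hidden,
                             max_incoming_seq_len, max_seq_len, b, s):
    # cache:  flat [b * max_seq_len, ...] tensor, s contiguous elements per row.
    # hidden: flat [b * max_incoming_seq_len, ...] tensor with the same row size.
    cache_flat = cache.reshape(b, max_seq_len * s)
    hidden_flat = hidden.reshape(b, max_incoming_seq_len * s)
    for batch_idx in range(b):
        to_copy = s * int(incoming_seq_len[batch_idx])
        offset = s * int(seq_len[batch_idx])
        cache_flat[batch_idx, offset:offset + to_copy] = hidden_flat[batch_idx, :to_copy]
    return cache

# Example: batch of 2, one head of size 4 (s = 4), 8 cache positions per sequence,
# 2 and 5 tokens already cached, one incoming token for each sequence.
cache = torch.zeros(2 * 8, 1, 4)
hidden = torch.ones(2 * 1, 1, 4)
attention_copy_reference(cache, torch.tensor([2, 5]), torch.tensor([1, 1]),
                         hidden, 1, 8, 2, 4)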