-
Notifications
You must be signed in to change notification settings - Fork 217
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[llama] Store KV Cache on CPU and Use PyTorch SPDA
for Next token generation
#1182
base: main
Are you sure you want to change the base?
Changes from all commits
bd65f6a
8589571
10ae35e
7430ba4
1d617f1
36bc93b
d8e3431
b2e2f48
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -349,7 +349,36 @@ def gaudi_llama_repeat_kv( | |
return query_states, key_states, value_states, attention_mask | ||
|
||
|
||
# FusedScaledDotProductAttention | ||
def gaudi_llama_repeat_kv_cpu( | ||
query_states: torch.Tensor, | ||
key_states: torch.Tensor, | ||
value_states: torch.Tensor, | ||
attention_mask: torch.Tensor, | ||
n_rep: int, | ||
): | ||
""" | ||
PyTorch SDPA CPU (flash-atten) kernel does not support GQA/MQA for now. | ||
So, expand k and v to num_query_heads. | ||
""" | ||
query_states = query_states.to("cpu") | ||
key_states = key_states.to("cpu") | ||
value_states = value_states.to("cpu") | ||
if attention_mask is not None: | ||
attention_mask = attention_mask.to("cpu") | ||
|
||
batch, num_key_value_heads, kv_len, head_dim = key_states.shape | ||
if n_rep == 1 or num_key_value_heads == 1: | ||
return query_states, key_states, value_states, attention_mask | ||
|
||
key_states = key_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, kv_len, head_dim) | ||
value_states = value_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, kv_len, head_dim) | ||
key_states = key_states.reshape(batch, num_key_value_heads * n_rep, kv_len, head_dim) | ||
value_states = value_states.reshape(batch, num_key_value_heads * n_rep, kv_len, head_dim) | ||
|
||
return query_states, key_states, value_states, attention_mask | ||
|
||
|
||
# FusedScaledDotProductAttention | ||
class ModuleFusedSDPA(torch.nn.Module): | ||
def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8): | ||
super().__init__() | ||
|
@@ -453,9 +482,8 @@ def get_k_proj_weight_dtype(self): | |
return self.k_proj.scales.dtype | ||
return self.k_proj.weight.dtype | ||
|
||
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): | ||
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"): | ||
cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) | ||
device = self.get_k_proj_weight().device | ||
dtype = self.config.torch_dtype | ||
self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) | ||
self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) | ||
|
@@ -642,44 +670,37 @@ def pre_attn_forward( | |
kv_seq_len = key_states.shape[-2] | ||
else: | ||
past_key_value = None | ||
fused_scaled_dot_product_attention = GaudiDistributedAttention( | ||
self.fused_scaled_dot_product_attention, self.fused_scaled_dot_product_attention_distributed | ||
) | ||
if use_flash_attention and FusedSDPA is not None: | ||
if q_len == 1: | ||
# next token | ||
attn_output = fused_scaled_dot_product_attention( | ||
query_states, | ||
key_states, | ||
value_states, | ||
attention_mask, | ||
0.0, | ||
False, | ||
None, | ||
"None", | ||
False, | ||
None, | ||
"None", | ||
) | ||
else: | ||
# first token | ||
softmax_mode = "fast" if flash_attention_fast_softmax else "None" | ||
if flash_attention_causal_mask: | ||
# causal masking on first token requires inputs to be of the same length | ||
attn_output = fused_scaled_dot_product_attention( | ||
query_states, | ||
key_states, | ||
value_states, | ||
None, | ||
0.0, | ||
True, | ||
None, | ||
softmax_mode, | ||
flash_attention_recompute, | ||
valid_sequence_lengths, | ||
"left", | ||
) | ||
else: | ||
|
||
kv_cache_on_host = key_states.device == torch.device("cpu") and value_states.device == torch.device("cpu") | ||
# CPU SDPA fot next token | ||
if kv_cache_on_host and q_len == 1 and not self.training: | ||
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv_cpu( | ||
query_states, key_states, value_states, attention_mask, self.num_key_value_groups | ||
) | ||
# pytorch https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html | ||
# dispatch to flash attention implementation | ||
attn_output = F.scaled_dot_product_attention( | ||
query_states, | ||
key_states, | ||
value_states, | ||
attn_mask=attention_mask, | ||
dropout_p=0.0, | ||
is_causal=False, | ||
scale=self.norm_factor, | ||
) | ||
attn_output = attn_output.to("hpu", non_blocking=True) | ||
|
||
else: | ||
if kv_cache_on_host: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you please explain what's the case switching kv_cache device? I thought line 656 is the case only when line 658. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this pr, we make kv cache store on cpu and do cpu sdpa only when generating the next token. The first token or prefill stage is performed on HPU due to its powerful computation ability under long-context scenario (long prompt in most cases). The full pipeline diagram shows on the pr description. |
||
key_states = key_states.to("hpu", non_blocking=True) | ||
value_states = value_states.to("hpu", non_blocking=True) | ||
|
||
fused_scaled_dot_product_attention = GaudiDistributedAttention( | ||
self.fused_scaled_dot_product_attention, self.fused_scaled_dot_product_attention_distributed | ||
) | ||
if use_flash_attention and FusedSDPA is not None: | ||
if q_len == 1: | ||
# next token | ||
attn_output = fused_scaled_dot_product_attention( | ||
query_states, | ||
key_states, | ||
|
@@ -688,35 +709,68 @@ def pre_attn_forward( | |
0.0, | ||
False, | ||
None, | ||
softmax_mode, | ||
flash_attention_recompute, | ||
"None", | ||
False, | ||
None, | ||
"None", | ||
) | ||
else: | ||
# first token | ||
softmax_mode = "fast" if flash_attention_fast_softmax else "None" | ||
if flash_attention_causal_mask: | ||
attn_output = fused_scaled_dot_product_attention( | ||
query_states, | ||
key_states, | ||
value_states, | ||
None, | ||
0.0, | ||
True, | ||
None, | ||
softmax_mode, | ||
flash_attention_recompute, | ||
valid_sequence_lengths, | ||
"left", | ||
) | ||
else: | ||
attn_output = fused_scaled_dot_product_attention( | ||
query_states, | ||
key_states, | ||
value_states, | ||
attention_mask, | ||
0.0, | ||
False, | ||
None, | ||
softmax_mode, | ||
flash_attention_recompute, | ||
None, | ||
"None", | ||
) | ||
|
||
else: | ||
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( | ||
query_states, key_states, value_states, attention_mask, self.num_key_value_groups | ||
) | ||
else: | ||
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv( | ||
query_states, key_states, value_states, attention_mask, self.num_key_value_groups | ||
) | ||
|
||
attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor | ||
attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor | ||
|
||
if attention_mask is not None: # no matter the length, we just slice it | ||
causal_mask = attention_mask | ||
if cache_position is not None: | ||
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] | ||
attn_weights = attn_weights + causal_mask | ||
if attention_mask is not None: # no matter the length, we just slice it | ||
causal_mask = attention_mask | ||
if cache_position is not None: | ||
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] | ||
attn_weights = attn_weights + causal_mask | ||
|
||
if attn_softmax_bf16: | ||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) | ||
else: | ||
# upcast attention to fp32 | ||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( | ||
query_states.dtype | ||
if attn_softmax_bf16: | ||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) | ||
else: | ||
# upcast attention to fp32 | ||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( | ||
query_states.dtype | ||
) | ||
attn_weights = torch.nn.functional.dropout( | ||
attn_weights, p=self.attention_dropout, training=self.training | ||
) | ||
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) | ||
attn_output = self.matmul_av(attn_weights, value_states) | ||
attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) | ||
attn_output = self.matmul_av(attn_weights, value_states) | ||
attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) | ||
|
||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): | ||
raise ValueError( | ||
|
@@ -862,8 +916,8 @@ def __init__(self, config: LlamaConfig, layer_idx: int): | |
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||
|
||
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): | ||
self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) | ||
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"): | ||
self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, device=device) | ||
|
||
def reorder_kv_cache(self, beam_idx: torch.LongTensor): | ||
return self.self_attn.reorder_kv_cache(beam_idx) | ||
|
@@ -1043,9 +1097,9 @@ def __init__(self, config: LlamaConfig): | |
# Initialize weights and apply final processing | ||
self.post_init() | ||
|
||
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): | ||
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"): | ||
for layer in self.layers: | ||
layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) | ||
layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, device=device) | ||
|
||
def reorder_kv_cache(self, beam_idx: torch.LongTensor): | ||
return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) | ||
|
@@ -1281,8 +1335,8 @@ def __init__(self, config, parallel_strategy: DistributedStrategy = NoOpStrategy | |
config.parallel_strategy = parallel_strategy | ||
super().__init__(config) | ||
|
||
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): | ||
self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) | ||
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"): | ||
self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, device=device) | ||
|
||
def reorder_kv_cache(self, beam_idx: torch.LongTensor): | ||
return self.model.reorder_kv_cache(beam_idx) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From line 1096 to 1107, I would like to suggest to change like this.
if not is_greedy_or_beam_and_bucket:
cache_device = "hpu"
if generation_config.kv_cache_on_host and self.config.model_type in ["llama"]:
print("Allocate KV Cache on CPU...")
cache_device = "cpu"
unwrap_deepspeed_model(self).allocate_kv_cache(
bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens,
device=cache_device
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I have updated it in 74e94ff. However, I can not remove the else line because I only modified the
modeling_llama.py
for this experimental feature.