[llama] Store KV Cache on CPU and Use PyTorch SDPA for Next Token Generation #1182

Open · wants to merge 8 commits into base: main

30 changes: 30 additions & 0 deletions examples/text-generation/README.md
@@ -677,6 +677,36 @@ ENABLE_EXPERIMENTAL_FLAGS=true python run_generation.py \
--load_quantized_model_with_autogptq
```

### Store KV Cache on CPU

Keeping the key/value cache on the CPU (host) side can reduce HPU VRAM usage at the cost of higher generation latency. This is a practical option for long-context serving of a large LLM on a single card.

Add the `--kv_cache_on_host` argument to enable it. The [PyTorch SDPA operator](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) is then used automatically for next-token generation to save data-transfer time. The first token (prefill) is not affected.

For example:
```bash
python run_generation.py \
--model_name_or_path 01-ai/Yi-34B-Chat \
--use_kv_cache \
--bf16 \
--attn_softmax_bf16 \
--reuse_cache \
--do_sample \
--dataset_name emozilla/pg19-test \
--batch_size 1 \
--max_input_tokens 11200 \
--column_name "text" \
--dataset_max_samples 1 \
--warmup 0 \
--n_iterations 1 \
--max_new_tokens 5000 \
--kv_cache_on_host
```

> [!NOTE]
> 1. `--kv_cache_on_host` currently supports only Llama models. It is also incompatible with `--use_hpu_graphs` and the FP8 data type.
> 2. Use it only when you hit an out-of-memory error on the HPU device, since it increases latency.
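
For illustration, the host-side decode step described above amounts to a plain PyTorch SDPA call after expanding the key/value heads. Below is a minimal sketch with made-up shapes; the actual implementation lives in `modeling_llama.py`:

```python
import torch
import torch.nn.functional as F

# Made-up shapes: batch=1, 32 query heads, 8 KV heads, head_dim=128, 11264 cached tokens
q = torch.randn(1, 32, 1, 128)      # single-token query, moved to the CPU for the decode step
k = torch.randn(1, 8, 11264, 128)   # KV cache already resident on the CPU
v = torch.randn(1, 8, 11264, 128)

# The CPU flash-attention kernel does not support GQA/MQA, so expand the KV heads to match the query heads
n_rep = q.shape[1] // k.shape[1]
k = k.repeat_interleave(n_rep, dim=1)
v = v.repeat_interleave(n_rep, dim=1)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
# `out` (1 x 32 x 1 x 128) is then copied back to the HPU for the rest of the decoder layer.
```

The PR itself expands the heads with an `expand`/`reshape` helper rather than `repeat_interleave`, but the effect is the same.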

## Language Model Evaluation Harness

The evaluation of LLMs can be done using the `lm_eval.py` script. It utilizes the [LM evaluation harness](https://github.com/EleutherAI/lm-evaluation-harness)
11 changes: 11 additions & 0 deletions examples/text-generation/run_generation.py
@@ -236,6 +236,11 @@ def setup_parser(parser):
action="store_true",
help="Whether to reuse key/value cache for decoding. It should save memory.",
)
parser.add_argument(
"--kv_cache_on_host",
action="store_true",
help="Store key/value cache on CPU instead of HPU device (only supports Llama for now). It should save VRAM in long context scenarios.",
)
parser.add_argument("--verbose_workers", action="store_true", help="Enable output from non-master workers")
parser.add_argument(
"--simulate_dyn_prompt",
@@ -363,6 +368,12 @@ def setup_parser(parser):
logger.warning(
"`--disk_offload` was tested only with fp8, it may not work with full precision. If error raises try to remove the --disk_offload flag."
)
if args.quant_config != "" and args.kv_cache_on_host:
logger.warning("`--kv_cache_on_host` is not supported with FP8 quantization. Setting this flag to False.")
args.kv_cache_on_host = False
if args.kv_cache_on_host and args.use_hpu_graphs:
logger.warning("`--kv_cache_on_host` is not supported with HPU graphs. Setting this flag to False.")
args.kv_cache_on_host = False
return args


1 change: 1 addition & 0 deletions examples/text-generation/utils.py
@@ -660,6 +660,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax
generation_config.trust_remote_code = args.trust_remote_code
generation_config.valid_sequence_lengths = None
generation_config.kv_cache_on_host = args.kv_cache_on_host

return generation_config

@@ -37,6 +37,8 @@ class GaudiGenerationConfig(GenerationConfig):
Whether to enable causal_mask if use Habana flash attention.
flash_attention_fast_softmax_mode (`bool`, *optional*):
Whether to use fast softmax with reduced precision if use Habana flash attention.
kv_cache_on_host (`bool`, *optional*):
Whether to store key/value cache on host (CPU).
"""

def __init__(self, **kwargs):
@@ -56,3 +58,4 @@ def __init__(self, **kwargs):
self.flash_attention_fast_softmax = kwargs.get("flash_attention_fast_softmax", None)
self.use_fused_rope = kwargs.get("use_fused_rope", None)
self.valid_sequence_lengths = kwargs.get("valid_sequence_lengths", None)
self.kv_cache_on_host = kwargs.get("kv_cache_on_host", False)
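
Since the new option is read from `kwargs` with a default of `False`, it can also be set directly on a generation config object. A hedged sketch, assuming `GaudiGenerationConfig` is importable from `optimum.habana.transformers.generation` (the example script instead wires it up from the `--kv_cache_on_host` CLI flag):

```python
from optimum.habana.transformers.generation import GaudiGenerationConfig

# Illustrative only: run_generation.py sets this through setup_generation_config() from the CLI flag.
generation_config = GaudiGenerationConfig(
    max_new_tokens=5000,
    use_cache=True,
    reuse_cache=True,
    kv_cache_on_host=True,
)
print(generation_config.kv_cache_on_host)  # True
```
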
20 changes: 17 additions & 3 deletions optimum/habana/transformers/generation/utils.py
@@ -1095,6 +1095,10 @@ def generate(
), "please set bucket_internal along with reuse_cache and bucket_size"
else:
assert generation_config.bucket_size >= 0, "please set valid bucket_size to use bucket_internal"
if generation_config.kv_cache_on_host:
assert self.config.model_type in [
"llama",
], "kv_cache_on_host only supported by Llama at the moment"

if self.config.model_type == "gemma2":
generation_config.cache_implementation = None
@@ -1277,9 +1281,19 @@ def generate(
if generation_config.use_cache and generation_config.reuse_cache:
bs, _ = input_ids.shape
if not is_greedy_or_beam_and_bucket:
unwrap_deepspeed_model(self).allocate_kv_cache(
bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens
)
if generation_config.kv_cache_on_host and self.config.model_type in ["llama"]:
logger.info("Allocating KV cache on CPU...")
cache_device = "cpu"
unwrap_deepspeed_model(self).allocate_kv_cache(
bs * generation_config.num_beams,
calculated_max_length,
token_idx + num_virtual_tokens,
device=cache_device,
)
else:
unwrap_deepspeed_model(self).allocate_kv_cache(
bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens
)

Collaborator:
For lines 1096 to 1107, I would like to suggest changing it like this:

if not is_greedy_or_beam_and_bucket:
    cache_device = "hpu"
    if generation_config.kv_cache_on_host and self.config.model_type in ["llama"]:
        print("Allocate KV Cache on CPU...")
        cache_device = "cpu"
    unwrap_deepspeed_model(self).allocate_kv_cache(
        bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens,
        device=cache_device
    )

Author:
Thanks, I have updated it in 74e94ff. However, I cannot remove the `else` branch because I only modified modeling_llama.py for this experimental feature.

if generation_config.use_cache:
model_kwargs["kv_cache_len"] = calculated_max_length
model_kwargs["kv_cache_pad_len"] = generation_config.max_new_tokens
190 changes: 122 additions & 68 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -349,7 +349,36 @@ def gaudi_llama_repeat_kv(
return query_states, key_states, value_states, attention_mask


# FusedScaledDotProductAttention
def gaudi_llama_repeat_kv_cpu(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: torch.Tensor,
n_rep: int,
):
"""
The PyTorch CPU SDPA (flash-attention) kernel does not support GQA/MQA for now,
so expand k and v to match the number of query heads.
"""
query_states = query_states.to("cpu")
key_states = key_states.to("cpu")
value_states = value_states.to("cpu")
if attention_mask is not None:
attention_mask = attention_mask.to("cpu")

batch, num_key_value_heads, kv_len, head_dim = key_states.shape
if n_rep == 1 or num_key_value_heads == 1:
return query_states, key_states, value_states, attention_mask

key_states = key_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, kv_len, head_dim)
value_states = value_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, kv_len, head_dim)
key_states = key_states.reshape(batch, num_key_value_heads * n_rep, kv_len, head_dim)
value_states = value_states.reshape(batch, num_key_value_heads * n_rep, kv_len, head_dim)

return query_states, key_states, value_states, attention_mask


# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8):
super().__init__()
@@ -453,9 +482,8 @@ def get_k_proj_weight_dtype(self):
return self.k_proj.scales.dtype
return self.k_proj.weight.dtype

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"):
cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)
device = self.get_k_proj_weight().device
dtype = self.config.torch_dtype
self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape)
self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape)
@@ -642,44 +670,37 @@ def pre_attn_forward(
kv_seq_len = key_states.shape[-2]
else:
past_key_value = None
fused_scaled_dot_product_attention = GaudiDistributedAttention(
self.fused_scaled_dot_product_attention, self.fused_scaled_dot_product_attention_distributed
)
if use_flash_attention and FusedSDPA is not None:
if q_len == 1:
# next token
attn_output = fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
attention_mask,
0.0,
False,
None,
"None",
False,
None,
"None",
)
else:
# first token
softmax_mode = "fast" if flash_attention_fast_softmax else "None"
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
attn_output = fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
None,
0.0,
True,
None,
softmax_mode,
flash_attention_recompute,
valid_sequence_lengths,
"left",
)
else:

kv_cache_on_host = key_states.device == torch.device("cpu") and value_states.device == torch.device("cpu")
# CPU SDPA for next token
if kv_cache_on_host and q_len == 1 and not self.training:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv_cpu(
query_states, key_states, value_states, attention_mask, self.num_key_value_groups
)
# PyTorch SDPA (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
# dispatches to its CPU flash-attention implementation
attn_output = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
scale=self.norm_factor,
)
attn_output = attn_output.to("hpu", non_blocking=True)

else:
if kv_cache_on_host:

Collaborator:
Can you please explain in which case the kv_cache device is switched? I thought the case on line 656 only happens under the condition on line 658.

Author:
In this PR, the KV cache is stored on the CPU and CPU SDPA is used only when generating the next token. The first token (prefill stage) is computed on the HPU because of its much higher compute throughput in long-context scenarios (usually long prompts). The full pipeline diagram is shown in the PR description.
So line 658 means PyTorch CPU SDPA (flash-attention) is used only when kv_cache_on_host is set, we are in next-token generation, and we are at inference time. Otherwise, the KV cache is transferred back to the HPU device if needed for the original operations.
Please let me know if you need more explanation or have suggestions. Thanks.

key_states = key_states.to("hpu", non_blocking=True)
value_states = value_states.to("hpu", non_blocking=True)

fused_scaled_dot_product_attention = GaudiDistributedAttention(
self.fused_scaled_dot_product_attention, self.fused_scaled_dot_product_attention_distributed
)
if use_flash_attention and FusedSDPA is not None:
if q_len == 1:
# next token
attn_output = fused_scaled_dot_product_attention(
query_states,
key_states,
@@ -688,35 +709,68 @@ def pre_attn_forward(
0.0,
False,
None,
softmax_mode,
flash_attention_recompute,
"None",
False,
None,
"None",
)
else:
# first token
softmax_mode = "fast" if flash_attention_fast_softmax else "None"
if flash_attention_causal_mask:
attn_output = fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
None,
0.0,
True,
None,
softmax_mode,
flash_attention_recompute,
valid_sequence_lengths,
"left",
)
else:
attn_output = fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
attention_mask,
0.0,
False,
None,
softmax_mode,
flash_attention_recompute,
None,
"None",
)

else:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
query_states, key_states, value_states, attention_mask, self.num_key_value_groups
)
else:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
query_states, key_states, value_states, attention_mask, self.num_key_value_groups
)

attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor
attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask
if cache_position is not None:
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask
if cache_position is not None:
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

if attn_softmax_bf16:
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype)
else:
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
query_states.dtype
if attn_softmax_bf16:
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype)
else:
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
query_states.dtype
)
attn_weights = torch.nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = self.matmul_av(attn_weights, value_states)
attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim)
attn_output = self.matmul_av(attn_weights, value_states)
attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
@@ -862,8 +916,8 @@ def __init__(self, config: LlamaConfig, layer_idx: int):
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"):
self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, device=device)

def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return self.self_attn.reorder_kv_cache(beam_idx)
@@ -1043,9 +1097,9 @@ def __init__(self, config: LlamaConfig):
# Initialize weights and apply final processing
self.post_init()

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"):
for layer in self.layers:
layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, device=device)

def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers)
@@ -1281,8 +1335,8 @@ def __init__(self, config, parallel_strategy: DistributedStrategy = NoOpStrategy
config.parallel_strategy = parallel_strategy
super().__init__(config)

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, device="hpu"):
self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, device=device)

def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return self.model.reorder_kv_cache(beam_idx)
3 changes: 3 additions & 0 deletions optimum/habana/transformers/models/modeling_all_models.py
@@ -55,6 +55,9 @@ def allocate(self, inp_seq_len, dtype, device, shape):

@staticmethod
def update(prev, cur, dim, idx, inp_seq_len):
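# With --kv_cache_on_host, the preallocated cache (`prev`) can live on the CPU while the new
# key/value states (`cur`) come from the HPU, so move them to the cache device before copying.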
cur = cur.to(prev.device)
if idx is not None:
idx = idx.to(prev.device)
orig_cur = cur
if prev.shape == cur.shape:
prev.copy_(cur)