Skip to content

Commit

Permalink
refact code and add README
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Zhentao <[email protected]>
  • Loading branch information
zhentaoyu committed Sep 11, 2024
1 parent 7168b27 commit ff3c54f
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 42 deletions.
21 changes: 21 additions & 0 deletions examples/text-generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,27 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \

For more details see [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa).

### Store KV Cache on CPU
Keeping key/value cache on CPU (host) side can decrease hpu vram in spite of it may damage generation latency. It's a practical solution in long context serving scenario with a large LLM on single card.

You can add `--kv_cache_on_host` arg to enable it. [Pytorch SDPA operator](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) will be automatically used to generate next token for saving data transfer time. First token is not be affected.

For exmaple:
```bash
python run_generation.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--use_kv_cache \
--bf16 \
--attn_softmax_bf16 \
--max_new_tokens 128 \
--reuse_cache \
--do_sample \
--prompt "Here is my prompt"
--kv_cache_on_host
```

> [!NOTE]
> `--kv_cache_on_host` only supports llama model for now. And it can not work with `--use_hpu_grapgs` and FP8 data type.
## Language Model Evaluation Harness

Expand Down
5 changes: 5 additions & 0 deletions examples/text-generation/run_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,11 @@ def setup_parser(parser):
"`--kv_cache_on_host` is not supported with FP8 quantization. Set this flag to False."
)
args.kv_cache_on_host = False
if args.kv_cache_on_host and args.use_hpu_graphs:
logger.warning(
"`--kv_cache_on_host` is not supported with HPU graphs. Set this flag to False."
)
args.kv_cache_on_host = False
return args


Expand Down
2 changes: 1 addition & 1 deletion optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,7 +1095,7 @@ def generate(
bs, _ = input_ids.shape
cache_device = "cpu" if generation_config.kv_cache_on_host else "hpu"
if not is_greedy_or_beam_and_bucket:
if self.config.model_type in ["llama"]:
if generation_config.kv_cache_on_host and self.config.model_type in ["llama"]:
print("Allocate KV Cache on CPU...")
unwrap_deepspeed_model(self).allocate_kv_cache(
bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens,
Expand Down
83 changes: 42 additions & 41 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,6 @@ def gaudi_llama_repeat_kv(
The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim)
The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim)
"""
query_states = query_states.to("hpu")
key_states = key_states.to("hpu")
value_states = value_states.to("hpu")
batch, num_key_value_heads, kv_len, head_dim = key_states.shape
if n_rep == 1 or num_key_value_heads == 1:
return query_states, key_states, value_states, attention_mask
Expand Down Expand Up @@ -656,49 +653,53 @@ def pre_attn_forward(
else:
past_key_value = None

bool kv_cache_on_host = (key_states.device() == "cpu" and value_states.device() == "cpu")
if use_flash_attention and FusedSDPA is not None and not kv_cache_on_host:
import habana_frameworks.torch.hpu as ht

softmax_mode = "fast" if flash_attention_fast_softmax else "None"
kv_cache_on_host = (key_states.device == "cpu" and value_states.device == "cpu")
# CPU SDPA fot next token
if kv_cache_on_host and q_len == 1 and not self.training:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv_cpu(
query_states, key_states, value_states, attention_mask, self.num_key_value_groups
)
# pytorch https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
# dispatch to flash attention implementation
attn_output = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
scale=self.norm_factor)
attn_output = attn_output.to("hpu")

if q_len == 1:
# next token
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, "None"
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
else:
if kv_cache_on_host:
key_states = key_states.to("hpu")
value_states = value_states.to("hpu")
if use_flash_attention and FusedSDPA is not None:
import habana_frameworks.torch.hpu as ht

softmax_mode = "fast" if flash_attention_fast_softmax else "None"

if q_len == 1:
# next token
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)

else:
if q_len == 1 and kv_cache_on_host:
# CPU SDPA fot next token
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv_cpu(
query_states, key_states, value_states, attention_mask, self.num_key_value_groups
)
# pytorch https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
# dispatch to flash attention implementation
attn_output = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
scale=self.norm_factor)
attn_output = attn_output.to("hpu")
else:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
query_states, key_states, value_states, attention_mask, self.num_key_value_groups
Expand Down

0 comments on commit ff3c54f

Please sign in to comment.