diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 360d1fe07a..e593b42801 100755 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -547,19 +547,26 @@ You can add `--kv_cache_on_host` arg to enable it. [Pytorch SDPA operator](https For exmaple: ```bash python run_generation.py \ ---model_name_or_path meta-llama/Llama-2-7b-hf \ +--model_name_or_path 01-ai/Yi-34B-Chat \ --use_kv_cache \ --bf16 \ --attn_softmax_bf16 \ ---max_new_tokens 128 \ --reuse_cache \ --do_sample \ ---prompt "Here is my prompt" +--dataset_name emozilla/pg19-test \ +--batch_size 1 \ +--max_input_tokens 11200 \ +--column_name "text" \ +--dataset_max_samples 1 \ +--warmup 0 \ +--n_iterations 1 \ +--max_new_tokens 5000 \ --kv_cache_on_host ``` > [!NOTE] -> `--kv_cache_on_host` only supports llama model for now. And it can not work with `--use_hpu_grapgs` and FP8 data type. +> 1. `--kv_cache_on_host` only supports llama model for now. And it can not work with `--use_hpu_grapgs` and FP8 data type. +> 2. Try to use it when you only meet HPU workspace allocation error (`OOM`) since it will increase latency. ## Language Model Evaluation Harness diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 6f1dcdd886..ee26d54dbc 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1093,10 +1093,10 @@ def generate( calculated_max_length = input_ids.shape[-1] + generation_config.max_new_tokens + num_virtual_tokens if generation_config.use_cache and generation_config.reuse_cache: bs, _ = input_ids.shape - cache_device = "cpu" if generation_config.kv_cache_on_host else "hpu" if not is_greedy_or_beam_and_bucket: if generation_config.kv_cache_on_host and self.config.model_type in ["llama"]: print("Allocate KV Cache on CPU...") + cache_device = "cpu" unwrap_deepspeed_model(self).allocate_kv_cache( bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens, device=cache_device