From c935a3ddbf0308377702ac1317ac759162da6774 Mon Sep 17 00:00:00 2001
From: jiqing-feng <107918818+jiqing-feng@users.noreply.github.com>
Date: Mon, 8 Apr 2024 17:15:38 +0800
Subject: [PATCH] Set use_cache in ipex model tests (#649)

---
 optimum/intel/ipex/modeling_base.py | 4 ++++
 tests/ipex/test_modeling.py         | 2 +-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 0664a8e6ac..8a7a4f2028 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -89,6 +89,10 @@ def ipex_jit_trace(model, task, use_cache):
 
     model.config.return_dict = False
 
+    if "past_key_values" in sample_inputs and use_cache:
+        # Make sure the model will output past_key_values in generation tasks
+        model.config.use_cache = True
+
     model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True)
     # Disable repack while jit tracing to reduce the memory
     ipex._C.disable_jit_linear_repack()
diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py
index af59900424..94a5ca9e16 100644
--- a/tests/ipex/test_modeling.py
+++ b/tests/ipex/test_modeling.py
@@ -253,7 +253,7 @@ def test_compare_to_transformers(self, model_arch):
     def test_pipeline(self, model_arch):
         model_id = MODEL_NAMES[model_arch]
         tokenizer = AutoTokenizer.from_pretrained(model_id)
-        model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=False)
+        model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)
         model.config.encoder_no_repeat_ngram_size = 0
         model.to("cpu")
         pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
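
For context, a minimal sketch of the cached-generation path this patch exercises. With use_cache left at its default of True, ipex_jit_trace sees past_key_values in the sample inputs and sets model.config.use_cache = True before tracing, so the traced model returns a KV cache that generation can reuse. This assumes only the APIs already used in the test above; the tiny checkpoint id is illustrative, the real tests draw their checkpoints from MODEL_NAMES[model_arch].

    # Sketch of cached generation through the IPEX-exported model, assuming the
    # same APIs as tests/ipex/test_modeling.py. Checkpoint id is illustrative.
    from transformers import AutoTokenizer, pipeline
    from optimum.intel import IPEXModelForCausalLM

    model_id = "hf-internal-testing/tiny-random-gpt2"  # illustrative tiny checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # use_cache defaults to True, so the jit trace is done with past_key_values
    # in the sample inputs and the traced model outputs a KV cache.
    model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)
    model.to("cpu")

    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    print(pipe("This is a sample", max_new_tokens=10)[0]["generated_text"])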